12 changes: 12 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -19,3 +19,15 @@ transforms:
stage: post_export
cleanup_input_constraints:
stage: post_export
quantize:
stage: pattern_matcher
quantize_moe:
stage: pattern_matcher
match_repeat_kv:
stage: pattern_matcher
match_eager_attention:
stage: pattern_matcher
match_grouped_attention:
stage: pattern_matcher
match_attention_layout:
stage: pattern_matcher
149 changes: 126 additions & 23 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py
@@ -7,7 +7,28 @@
import torch.nn as nn
import torch.nn.functional as F

# TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention.

def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
"""Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)"""
if logit_cap is not None and logit_cap > 0.0:
return logit_cap * torch.tanh(attn_scores / logit_cap)
return attn_scores
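# Worked example: scores far below the cap pass through nearly unchanged, while
# large scores saturate near +/- logit_cap, e.g. with logit_cap = 30.0:
#
#   scores = torch.tensor([0.5, 5.0, 50.0])
#   _apply_logit_softcapping(scores, 30.0)  # ~tensor([0.50, 4.95, 27.93])
#   _apply_logit_softcapping(scores, None)  # returned unchanged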


def _convert_boolean_mask_to_float(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""Convert boolean attention mask to floating point mask.
Args:
attn_mask: Boolean tensor where True allows attention, False blocks it
dtype: Target dtype for the output mask
Returns:
Floating point mask where True -> 1.0, False -> -inf
"""
if attn_mask.dtype == torch.bool:
float_mask = torch.zeros_like(attn_mask, dtype=dtype)
float_mask = float_mask.masked_fill(attn_mask, 1.0) # True -> 1.0
float_mask = float_mask.masked_fill(~attn_mask, float("-inf")) # False -> -inf
return float_mask
return attn_mask
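# Example: a [1, 2] boolean mask where only the first key is visible becomes
# [[1.0, -inf]]; added to the scores, the False position is blocked, while the
# uniform +1.0 on True positions cancels out in the row-wise softmax.
#
#   mask = torch.tensor([[True, False]])
#   _convert_boolean_mask_to_float(mask, torch.float32)  # tensor([[1., -inf]])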


@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
@@ -77,19 +98,96 @@ def grouped_sdpa(
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
sinks: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
logit_cap: Optional[float] = None,
) -> torch.Tensor:
"""SDPA attention that can handle GQA."""
"""SDPA attention that can handle GQA. Expects bnsd format inputs."""
b, n_heads, s_q, head_dim = query.shape # bnsd format: [batch, num_heads, seq_len, head_dim]
_, n_kv_heads, s_k, _ = key.shape # bnsd format: [batch, num_kv_heads, seq_len, head_dim]

# Inputs are already in bnsd format, no need to transpose
query_t = query # [b, n_heads, s_q, head_dim]
key_t = key # [b, n_kv_heads, s_k, head_dim]
value_t = value # [b, n_kv_heads, s_k, v_head_dim]

# Handle GQA by repeating KV if needed
if n_heads != n_kv_heads:
n_rep = n_heads // n_kv_heads
key_t = repeat_kv(key_t, n_rep)
value_t = repeat_kv(value_t, n_rep)

# Set scale
if scale is None:
scale = 1.0 / math.sqrt(head_dim)

# Compute attention scores: Q @ K^T
attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale # [b, n_heads, s_q, s_k]

# Apply attention mask if provided
if attn_mask is not None:
# Convert boolean mask to float if needed
attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype)
attn_scores = attn_scores + attn_mask

# Apply causal mask if specified and only during the context phase
if is_causal and s_q == s_k: # Only apply causal mask during context processing
causal_mask = torch.triu(
torch.ones(s_q, s_k, device=query.device, dtype=torch.bool),
diagonal=1, # Use diagonal=1 for standard causal masking
)
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))

# Apply sliding window mask if specified
if sliding_window is not None and sliding_window > 0:
# Handle position calculation for both context and generation phases
if s_q == s_k:
# Context phase: standard position calculation
query_positions = torch.arange(s_q, device=query.device)
key_positions = torch.arange(s_k, device=query.device)
else:
# Generation phase: query is at position s_k (after the cache)
query_positions = torch.arange(s_k, s_k + s_q, device=query.device) # positions [s_k, ..., s_k + s_q - 1]
key_positions = torch.arange(s_k, device=query.device) # [0,1,2,...,s_k-1]

# Create position difference matrix: query_pos - key_pos
pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0) # [s_q, s_k]

# Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window) # [s_q, s_k]
attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
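# Worked example: with s_k = 6 cached keys, s_q = 1 and sliding_window = 4,
# query_positions = [6] and key_positions = [0, 1, 2, 3, 4, 5], giving
# pos_diff = [6, 5, 4, 3, 2, 1]; entries with pos_diff >= 4 are masked, so the
# query only attends to the keys at positions 3, 4 and 5.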

# Apply logit softcapping if enabled
attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)

# Apply sinks if provided
if sinks is not None:
# Sinks act as one extra per-head logit that joins the softmax normalizer but
# contributes nothing to the output (matching the reference "concatenate a sink
# column" formulation). sinks has n_heads elements, one per head.
# Expand sinks to [b, n_heads, s_q, 1] so they broadcast against attn_scores
sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(
b, n_heads, s_q, 1
) # [b, n_heads, s_q, 1]

# Numerically stable softmax with the sink term added only to the normalizer
logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values
sinks = torch.exp(sinks_expanded - logits_max)
unnormalized_scores = torch.exp(attn_scores - logits_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
# The sink only enlarges the softmax normalizer, so the output is computed from the key scores directly
attn_out = torch.matmul(scores, value_t) # [b, n_heads, s_q, v_head_dim]
else:
attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
attn_out = torch.matmul(attn_weights, value_t) # [b, n_heads, s_q, v_head_dim]

return F.scaled_dot_product_attention(
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=True,
)
# Apply dropout if specified (training=False makes this an identity op here)
if dropout_p > 0.0:
attn_out = F.dropout(attn_out, p=dropout_p, training=False)

# Return in bnsd format (same as input format)
return attn_out
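# Minimal equivalence sketch (assuming this module is importable): the sink
# branch above matches the "concatenate one sink logit per head, then drop its
# probability column" formulation, because the max-shift cancels in the ratio.
#
#   b, n, s_q, s_k = 1, 2, 3, 5
#   scores = torch.randn(b, n, s_q, s_k)
#   sink = torch.randn(n)
#   ref = torch.softmax(
#       torch.cat([scores, sink.view(1, n, 1, 1).expand(b, n, s_q, 1)], dim=-1), dim=-1
#   )[..., :-1]
#   m = scores.max(dim=-1, keepdim=True).values
#   num = torch.exp(scores - m)
#   den = num.sum(-1, keepdim=True) + torch.exp(sink.view(1, n, 1, 1) - m)
#   assert torch.allclose(num / den, ref, atol=1e-6)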


@grouped_sdpa.register_fake
@@ -101,16 +199,19 @@ def grouped_sdpa_fake(
dropout_p=0.0,
is_causal=False,
scale=None,
sinks=None,
sliding_window=None,
logit_cap=None,
):
"""Fake implementation of grouped SDPA."""
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
query: torch.Tensor, # layout: [b, n, s_q, d]
key: torch.Tensor, # layout: [b, n, s_k, d]
value: torch.Tensor, # layout: [b, n, s_k, d]
query: torch.Tensor, # layout: [b, s_q, n, d]
key: torch.Tensor, # layout: [b, s_k, n, d]
value: torch.Tensor, # layout: [b, s_k, n, d]
attn_mask: Optional[torch.Tensor] = None, # layout: [b, n, s_q, s_k]
dropout_p: float = 0.0,
is_causal: bool = False,
@@ -124,14 +225,16 @@ def bsnd_grouped_sdpa(
Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
original sdpa op!
"""
# let's transpose to bnsd so we can use the grouped sdpa
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()

out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)

# let's transpose back to bnsd
# Transpose inputs to bnsd format for grouped_sdpa
query = query.transpose(1, 2).contiguous() # [b, s_q, n, d] -> [b, n, s_q, d]
key = key.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d]
value = value.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d]

# Call grouped_sdpa with bnsd inputs
out = grouped_sdpa(
query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
)
# Transpose back to bsnd format
return out.transpose(1, 2).contiguous()


@@ -103,7 +103,7 @@ def _torch_generate_mha(
# Apply sinks if provided (following the model file pattern)
if sinks is not None:
# Concatenate sinks to attention scores
sinks = sinks.reshape(-1, 1, 1).expand(-1, attn_scores.shape[-2], -1)
sinks = sinks.reshape(-1, 1, 1)
attn_weights = torch.cat([attn_scores, sinks], dim=-1)
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
# Use only the non-sink portion for computing output (ignore sinks)
@@ -202,9 +202,7 @@ def _torch_context_mha(
) # [seq_len_i, kv_seq_len]

# Sliding window mask: block attention when pos_diff >= sliding_window_size (pos_diff < 0 is already excluded by the causal mask)
sliding_window_mask = (pos_diff < 0) | (
pos_diff >= sliding_window_size
) # [seq_len_i, kv_seq_len]
sliding_window_mask = pos_diff >= sliding_window_size

# Combine causal and sliding window masks
combined_mask = causal_mask | sliding_window_mask
@@ -219,14 +217,14 @@
# Apply sinks if provided (following the model file pattern)
if sinks is not None:
# Concatenate sinks to attention scores
sinks = sinks.reshape(1, -1, 1, 1).expand(
attn_scores.shape[0], -1, attn_scores.shape[-2], -1
new_sinks = sinks.reshape(1, -1, 1, 1).expand(
attn_scores.shape[0], -1, attn_scores.shape[2], 1
)
attn_weights = torch.cat([attn_scores, sinks], dim=-1)
attn_weights = torch.cat([attn_scores, new_sinks], dim=-1)
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
# Use only the non-sink portion for computing output (ignore sinks)
attn_out = torch.matmul(
attn_weights[..., : -sinks.size(-1)], v_seq_t
attn_weights[..., : -new_sinks.size(-1)], v_seq_t
) # [1, n_heads, seq_len_i, v_head_dim]
else:
attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
@@ -17,7 +17,8 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
rank, world_size = get_rank_world_size()
assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO)
# Use AllReduceStrategy.NCCL until https://nvbugspro.nvidia.com/bug/5331013 is fixed, then switch back to AllReduceStrategy.AUTO
torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.NCCL)
return torch_op(tensor, all_reduce_params=all_reduce_params)

@torch.library.custom_op(
27 changes: 23 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -76,14 +76,24 @@ class AutoModelForCausalLMFactory(ModelFactory):
"max_position_embeddings": 1024,
}

def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
"""Get the max position embeddings config for the model."""
return {
"max_position_embeddings": self.max_seq_len,
}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._quant_config: Optional[Dict] = None

# Ingest defaults for tokenizer and model kwargs
self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
self.model_kwargs = deep_merge_dicts(self._model_defaults, self.model_kwargs)
self.model_kwargs = deep_merge_dicts(
self._model_defaults,
self.model_kwargs,
self._get_max_position_embeddings_config(),
)
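# Sketch of the intended precedence, assuming deep_merge_dicts lets later
# arguments override earlier ones (values below are hypothetical):
#
#   defaults = {"max_position_embeddings": 1024}   # _model_defaults
#   user = {"torch_dtype": "bfloat16"}             # caller-provided model_kwargs
#   override = {"max_position_embeddings": 4096}   # _get_max_position_embeddings_config()
#   # merged result: {"max_position_embeddings": 4096, "torch_dtype": "bfloat16"}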

# special handling for torch_dtype in model_kwargs since HF does not correctly update
# torch_dtype string to an actual torch.dtype object (only with default)
@@ -295,7 +305,7 @@ def _prefetch_checkpoint(self, model_name_or_path: str, skip_prefetch_weights: b
# at this point it should be a directory (either the original one or the download dir)
assert os.path.isdir(fetched_dir), f"Checkpoint path {fetched_dir} is not a directory."

self._load_quantization_config()
self._load_quantization_config(fetched_dir)

return fetched_dir

@@ -313,13 +323,13 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
# model-transformed weights, leading to unexpected key mismatches or format issues.
load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)

def _load_quantization_config(self):
def _load_quantization_config(self, fetched_dir: str):
"""Load the quantization config from the model directory if not done already."""
if self._quant_config is not None:
return

assert self.model
hf_quant_config_file = os.path.join(self.model, "hf_quant_config.json")
hf_quant_config_file = os.path.join(fetched_dir, "hf_quant_config.json")
if os.path.exists(hf_quant_config_file):
with open(hf_quant_config_file, "r") as file:
quantization_config = json.load(file)
@@ -344,6 +354,15 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
},
}

def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
"""Get the max position embeddings config for the model."""
return {
"max_position_embeddings": self.max_seq_len,
"text_config": {
"max_position_embeddings": self.max_seq_len,
},
}

@property
def automodel_from_config(self):
return AutoModelForImageTextToText.from_config
52 changes: 38 additions & 14 deletions tensorrt_llm/_torch/auto_deploy/transform/interface.py
@@ -227,18 +227,26 @@ def __call__(
# run or skip the transform
if self.config.enabled:
# run graph pre-cleanup
self._run_pre_cleanup(gm, info_last)

# run the transform in a error-handling wrapper
try:
gm, info = self._apply(gm, cm, factory)
except Exception as e:
error_msg = f"Transform {t_name} failed"
if self.config.skip_on_error:
is_clean_pre, has_valid_shapes_pre = self._run_pre_cleanup(gm, info_last)

# run the transform in an error-handling wrapper if desired
if self.config.skip_on_error:
try:
gm, info = self._apply(gm, cm, factory)
except Exception as e:
error_msg = f"Transform {t_name} failed"
ad_logger.warning(f"{error_msg}: {e}")
info = TransformInfo(skipped=True, num_matches=0)
else:
raise TransformError(error_msg) from e
else:
# run without the wrapper so the original exception and traceback surface for debugging
gm, info = self._apply(gm, cm, factory)

# we cannot say it's clean if the previous wasn't clean even if this one is
# create new info object with updated cleanup status
info_dict = info.model_dump()
info_dict["is_clean"] &= is_clean_pre
info_dict["has_valid_shapes"] &= has_valid_shapes_pre
info = TransformInfo(**info_dict)

# run graph post-cleanup
info = self._run_post_cleanup(gm, info)
@@ -279,20 +287,36 @@ def _set_autodeploy_meta(self, gm: GraphModule, autodeploy_meta: AutodeployMeta)
gm.meta[self._autodeploy_meta_key] = autodeploy_meta

@final
def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> None:
def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> Tuple[bool, bool]:
"""Run graph cleanup before the transform.

Args:
gm: The graph module to run cleanup on.
info: The last transform info.

Returns:
A tuple of (is_clean, has_valid_shapes) indicating the cleanup status after the
pre-cleanup.

This is used to ensure the transform is applied to a clean graph as needed by the transform.
"""
if not self.config.requires_clean_graph:
return
return info.is_clean, info.has_valid_shapes

is_clean = info.is_clean
has_valid_shapes = is_clean and info.has_valid_shapes

# decide whether to run cleanup based on the config and the incoming info
if self.config.requires_shape_prop and not (info.is_clean and info.has_valid_shapes):
if self.config.requires_shape_prop and not has_valid_shapes:
with lift_to_meta(gm):
canonicalize_graph(gm, shape_prop=True)
elif self.config.requires_clean_graph and not info.is_clean:
is_clean = True
has_valid_shapes = True
elif self.config.requires_clean_graph and not is_clean:
canonicalize_graph(gm)
is_clean = True

return is_clean, has_valid_shapes
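# Branch summary of the pre-cleanup above:
#   requires_clean_graph == False              -> incoming (is_clean, has_valid_shapes) passed through
#   requires_shape_prop and shapes not valid   -> canonicalize with shape_prop, returns (True, True)
#   requires_clean_graph and graph not clean   -> plain canonicalize, returns (True, False)
#   otherwise (graph already clean)            -> returns (True, has_valid_shapes)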

@final
def _run_post_cleanup(self, gm: GraphModule, info: TransformInfo) -> TransformInfo: