Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tensorrt_llm/_torch/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,15 @@ def skip_forward(
if hasattr(module, 'skip_forward'):
module.forward = module.skip_forward
remove_weights(module, ignore_modules)
else:
logger.warning(
f"Fail to skip forward since {module.__class__.__name__} "
f"does not have `skip_forward`.")


def forward_after_recv(forward_fn):
if hasattr(forward_fn, "__wrapped_by_forward_after_recv__"):
return forward_fn

def forward_after_recv_fn(
position_ids,
Expand All @@ -176,10 +182,13 @@ def forward_after_recv_fn(
**kwargs,
)

forward_after_recv_fn.__wrapped_by_forward_after_recv__ = True
return forward_after_recv_fn


def forward_before_send(forward_fn):
if hasattr(forward_fn, "__wrapped_by_forward_before_send__"):
return forward_fn

def forward_before_send_fn(
position_ids,
Expand All @@ -204,6 +213,7 @@ def forward_before_send_fn(
pp_send(hidden_states)
return output

forward_before_send_fn.__wrapped_by_forward_before_send__ = True
return forward_before_send_fn


Expand Down Expand Up @@ -411,6 +421,8 @@ def __pp_init__(self):
for module in self.epilogue:
skip_forward(module)

self.model.__pp_init__()

def __post_init__(self):
# 1. mixed precision
quant_config_dict = self.model_config.quant_config_dict
Expand Down