Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def detect_sharding_from_factory_config(
num_simple_shards = 0
num_row_col_shards = 0

for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op):
for lin_node in filtered_nodes(gm.graph.nodes, [is_linear_op, is_fake_quantized_linear_op]):
# use node's weight name to get the module name
module_name = lin_node.args[1].target

Expand Down Expand Up @@ -368,7 +368,7 @@ def detect_sharding_from_factory_config(
)
num_row_col_shards += 1
else:
ad_logger.warning("Invalid sharding config. Skipping.")
ad_logger.warning(f"Unsupported sharding action {config}. Skipping.")
else:
# TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
Expand All @@ -387,7 +387,19 @@ def detect_sharding_from_factory_config(
)
num_simple_shards += 1
else:
ad_logger.warning("Invalid sharding config. Skipping.")
ad_logger.warning(
f"Unsupported sharding action {config}. Fallback to simple shard"
)
sharding_config.tp_transforms.append(
TPShardingInfo.from_node(
lin_node,
split_dim=SplitDimension.COLUMN,
rank=rank,
world_size=world_size,
dist_op="all_gather",
min_local_shape=1,
)
)
# after successful match, break the loop
break

Expand Down
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ def filtered_nodes(
for node in nodes:
if target(node):
yield node
elif isinstance(target, Iterable) and all(isinstance(t, Callable) for t in target):
for node in nodes:
for t in target:
if t(node):
yield node
break
else:
# Handle the case where target or ops contains operations
operations = ops if ops is not None else target
Expand Down