colossalai/autochunk/estimate_memory.py (2 additions, 7 deletions)

@@ -6,12 +6,7 @@

from colossalai.fx.profiler import activation_size, parameter_size

-from .utils import (
-    delete_free_var_from_last_use,
-    find_idx_by_name,
-    get_node_shape,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node


class EstimateMemory(object):
@@ -240,7 +235,7 @@ def estimate_chunk_inference_mem(
elif node.op == "output":
continue
# no change for non compute node
-elif is_non_compute_node_except_placeholder(node):
+elif is_non_memory_node(node):
act_memory_peak_log.append(act_memory)
# node is a compute op
# calculate tmp, output node and delete node memory
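Note: is_non_memory_node (added in utils.py below) widens the old is_non_compute_node_except_placeholder test so that getitem and output nodes also skip memory accounting; for such nodes the walk simply carries the running activation total forward. A toy sketch of the predicate's behavior, using a hypothetical stand-in class rather than torch.fx.Node:

from dataclasses import dataclass

@dataclass
class FakeNode:  # minimal stand-in for torch.fx.Node (op/name only)
    op: str
    name: str

def is_non_memory_node(node) -> bool:  # mirrors the new utils.py predicate
    if "getitem" in node.name:
        return True
    if "output" in node.op:
        return True
    # the real code falls through to is_non_compute_node(node) here
    return node.op in ("placeholder", "get_attr", "output")

print(is_non_memory_node(FakeNode("call_function", "getitem_3")))  # True
print(is_non_memory_node(FakeNode("call_function", "matmul_1")))   # False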

colossalai/autochunk/trace_flow.py (32 additions, 11 deletions)

@@ -118,16 +118,34 @@ def check_index_duplicate(self, chunk_infos, return_dim=False):

def _assgin_single_node_flow(
self,
-arg_node,
-start_idx,
-end_idx,
-cur_node_dim,
-cur_node_compute,
-cur_node_source,
-cur_node_fix_dim,
-all_node_info,
-next_node_list,
-):
+arg_node: Node,
+start_idx: int,
+end_idx: int,
+cur_node_dim: int,
+cur_node_compute: Dict,
+cur_node_source: Dict,
+cur_node_fix_dim: List,
+all_node_info: Dict,
+next_node_list: List,
+) -> bool:
+"""
+Given the current node and one of its arg nodes,
+this function finds the arg node's chunk dim and fix dim.
+
+Args:
+arg_node (Node): input node
+start_idx (int): chunk region start
+end_idx (int): chunk region end
+cur_node_dim (int): current node chunk dim
+cur_node_compute (Dict): current node compute dict
+cur_node_source (Dict): current node source dict
+cur_node_fix_dim (List): current node fix dim
+all_node_info (Dict): all node chunk info in the chunk region
+next_node_list (List)
+
+Returns:
+bool: True if this node can be added to the flow, False otherwise.
+"""
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
# arg in chunk range or be inputs
if not (start_idx <= arg_idx < end_idx):
@@ -142,6 +160,9 @@
arg_dim = None
else:
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
+# chunk dim should be None if shape size is 1
+if get_node_shape(arg_node)[arg_dim] == 1:
+arg_dim = None
else:
arg_dim = None
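Note: the size-1 check above matters for broadcast args. If an input has extent 1 along the chunk dim, every chunk reuses it whole, so no chunk dim should be propagated to it. A small PyTorch sketch of that reuse (shapes are illustrative):

import torch

B, N = 4, 8
x = torch.randn(B, N)      # chunked along dim 0
bias = torch.randn(1, N)   # size 1 on the chunk dim: broadcast, never split

chunks = [x[i:i + 2] + bias for i in range(0, B, 2)]  # bias reused whole
assert torch.equal(torch.cat(chunks, dim=0), x + bias)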

@@ -184,7 +205,7 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx):

# get all valid args
arg_list = []
-for arg in cur_node.args:
+for arg in cur_node.all_input_nodes:
if type(arg) != type(cur_node):
continue
if is_non_compute_node(arg):
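Note: switching from cur_node.args to cur_node.all_input_nodes also picks up tensor inputs passed as keyword arguments, which .args alone would miss. A minimal torch.fx illustration (the module and names are hypothetical):

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x, w):
        return F.linear(x, weight=w)  # weight arrives as a kwarg

gm = symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function":
        print(node.args)             # (x,): the kwarg input is not visible
        print(node.all_input_nodes)  # [x, w]: inputs from args and kwargs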

colossalai/autochunk/trace_indice.py (52 additions, 4 deletions)

@@ -432,6 +432,38 @@ def _assign_ones_like_indice(self, node: Node, node_idx: int):
"""
self._assign_all_indice(node, node_idx)

+def _assign_cat_indice(self, node: Node, node_idx: int):
+"""
+Assign indice for cat op.
+
+Args:
+node (node)
+node_idx (int)
+"""
+nodes_in = flat_list(node.args[0])
+self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
+for n in nodes_in[1:]:
+self._mark_computation_from_node(n, node)
+cat_dim = node.kwargs["dim"]
+self._del_dim(node_idx, cat_dim)
+self._add_dim(node_idx, cat_dim)
+
+def _assign_sum_indice(self, node: Node, node_idx: int):
+"""
+Assign indice for sum op.
+
+Args:
+node (node)
+node_idx (int)
+"""
+nodes_in = flat_list(node.args[0])
+self._add_dim(node_idx, 0)
+self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
+for n in nodes_in[1:]:
+self._mark_computation_from_node(n, node)
+cat_dim = node.kwargs["dim"]
+self._del_dim(node_idx, cat_dim)
+
def _assign_getitem_indice(self, node: Node, node_idx: int):
"""
Assign indice for getitem.
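Note on the two new handlers above: cat keeps the tensor's rank but gives the concatenation dim a fresh indice (hence _del_dim followed by _add_dim at the same position), while sum removes the reduced dim. The shape bookkeeping they encode, in plain PyTorch terms:

import torch

a, b = torch.randn(2, 3), torch.randn(2, 5)
assert torch.cat([a, b], dim=1).shape == (2, 8)  # rank kept, dim 1 re-derived

x = torch.randn(2, 3, 4)
assert torch.sum(x, dim=1).shape == (2, 4)       # the reduced dim disappears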
@@ -442,7 +474,16 @@ def _assign_getitem_indice(self, node: Node, node_idx: int):
node_idx (int)
"""
node_args = flat_list(node.args[1:])
-if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
+flag = False
+for node_arg in node_args:
+node_arg_str = str(node_arg)
+if any(i == node_arg_str for i in ["None", "Ellipsis"]):
+flag = True
+break
+if "slice" in node_arg_str:
+flag = True
+break
+if flag == False:
return

# node args should be like [Ellipsis, slice(start, step, end), None]
@@ -461,8 +502,11 @@ def _assign_getitem_indice(self, node: Node, node_idx: int):
shape_gap = len(node_shape) - len(node_args) + 1
origin_idx_count += shape_gap
new_idx_count += shape_gap
-# slice(None, None, None) means all indexes, doesn't support other slice
-elif "slice(None, None, None)" == node_arg_str:
+# slice(None, None, None) means all indexes
+elif "slice" in node_arg_str:
+if "slice(None, None, None)" != node_arg_str:
+self._del_dim(node_idx, new_idx_count)
+self._add_dim(node_idx, new_idx_count)
origin_idx_count += 1
new_idx_count += 1
# None means a new dim
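Note: with this change, narrowing slices are traced instead of rejected: slice(None, None, None) keeps the dim's indice untouched, while any other slice re-derives it at the same position. The string forms the handler matches come from Python's slicing desugaring, for example:

import torch

x = torch.randn(4, 8)
# x[:, 1:3] lowers to getitem(x, (slice(None, None, None), slice(1, 3, None)))
assert x[:, 1:3].shape == (4, 2)      # dim kept, extent changed: new indice
assert x[:, None].shape == (4, 1, 8)  # None inserts a new dim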
@@ -565,7 +609,7 @@ def trace_indice(self):
self._assign_view_reshape_indice(node, idx)
elif "unsqueeze" in node.name:
self._assign_unsqueeze_indice(node, idx)
-elif any(i in node.name for i in ["to", "contiguous"]):
+elif any(i in node.name for i in ["to", "contiguous", "clone"]):
self._assgin_no_change_indice(node, idx)
elif "new_ones" in node.name:
self._assign_ones_like_indice(node, idx)
@@ -574,6 +618,8 @@
elif node.op == "call_function":
if "linear" in node.name:
self._assign_linear_indice(node, idx)
elif "cat" in node.name:
self._assign_cat_indice(node, idx)
elif "matmul" in node.name:
self._assign_matmul_indice(node, idx)
elif "softmax" in node.name:
@@ -586,6 +632,8 @@
self._assign_dropout_indice(node, idx)
elif "einsum" in node.name:
self._assign_einsum_indice(node, idx)
elif "sum" in node.name:
self._assign_sum_indice(node, idx)
elif "layer_norm" in node.name:
self._assign_layernorm_indice(node, idx)
elif "getitem" in node.name:

colossalai/autochunk/utils.py (16 additions, 4 deletions)

@@ -3,10 +3,12 @@
from torch.fx.node import Node


-def flat_list(inputs):
+def flat_list(inputs: Any) -> List:
"""
Flatten a list by recursion.
"""
+if not (isinstance(inputs, list) or isinstance(inputs, set) or isinstance(inputs, tuple)):
+return [inputs]
res = []
for i in inputs:
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
@@ -16,7 +18,7 @@ def flat_list(inputs):
return res
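Note: the new guard makes flat_list total over arbitrary inputs: a bare, non-iterable argument now comes back wrapped in a single-element list instead of raising. Usage sketch, assuming the patched helper:

from colossalai.autochunk.utils import flat_list

assert flat_list(5) == [5]                     # new: a bare value is wrapped
assert flat_list([1, [2, (3,)]]) == [1, 2, 3]  # nested containers flattened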


-def find_first_tensor_arg(node):
+def find_first_tensor_arg(node: Node) -> Node:
"""
Find the first input tensor arg for a node
"""
@@ -26,24 +28,34 @@ def find_first_tensor_arg(node):
raise RuntimeError()


-def is_non_compute_node(node):
+def is_non_compute_node(node: Node) -> bool:
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
return True
if "getitem" in node.name:
node_args = flat_list(node.args[1:])
for node_arg in node_args:
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
return False
if "slice" in str(node_arg):
return False
return True
return False


-def get_node_shape(node):
+def get_node_shape(node: Node) -> List:
if hasattr(node.meta["tensor_meta"], "shape"):
return node.meta["tensor_meta"].shape
return None


+def is_non_memory_node(node: Node) -> bool:
+if "getitem" in node.name:
+return True
+if "output" in node.op:
+return True
+return is_non_compute_node(node)
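Note: treating getitem and output as memory-free is consistent with PyTorch semantics as far as I can tell: basic indexing returns a view that shares storage with its source, and the graph's output node allocates nothing itself. A quick check of the view behavior:

import torch

x = torch.randn(4, 4)
v = x[0]                             # basic indexing returns a view
assert v.data_ptr() == x.data_ptr()  # same storage, no new allocation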


def is_non_compute_node_except_placeholder(node):
if "placeholder" in node.op:
return False

tests/test_autochunk/test_evoformer_codegen.py (1 addition, 1 deletion)

@@ -130,7 +130,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
},
)
graph.set_codegen(codegen)
-gm = ColoGraphModule(model, graph)
+gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()

# assert we have inserted chunk
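Note: ckpt_codegen=False is presumably needed so that ColoGraphModule does not install its activation-checkpoint codegen over the AutoChunk codegen attached just above; the chunked code generation set via graph.set_codegen(codegen) then survives recompile(). Sketch of the intended sequence, reusing the names from the test:

graph.set_codegen(codegen)                              # attach AutoChunk codegen
gm = ColoGraphModule(model, graph, ckpt_codegen=False)  # keep it in place
gm.recompile()                                          # emit the chunked forward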