168 changes: 46 additions & 122 deletions colossalai/autochunk/autochunk_codegen.py

Large diffs are not rendered by default.

92 changes: 25 additions & 67 deletions colossalai/autochunk/trace_flow.py
@@ -10,6 +10,7 @@


class TraceFlow(object):
+
def __init__(self, trace_indice: TraceIndice) -> None:
self.trace_indice = trace_indice

@@ -28,9 +29,7 @@ def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_trace_source = end_node_trace["source"][end_dim]
- sorted_source = sorted(
-     end_node_trace_source.items(), key=lambda d: d[0], reverse=True
- )
+ sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
for node_idx, node_dim in sorted_source:
if node_idx == start_node_idx and start_dim in node_dim:
return True
@@ -70,26 +69,20 @@ def _find_inherit_dim(self, input_node, input_dim, node):
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))):
- if (
-     input_node_idx in node_trace_source[node_dim]
-     and input_dim[0] in node_trace_source[node_dim][input_node_idx]
- ):
+ if (input_node_idx in node_trace_source[node_dim]
+         and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
return node_dim
return None

def check_index_duplicate(self, chunk_infos, return_dim=False):
input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
- inherit_dim = self._find_inherit_dim(
-     input_node, v, self.trace_indice.node_list[k]
- )
+ inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
if inherit_dim:
input_dim_after_node[k] = inherit_dim

- for node in self.trace_indice.node_list[
-     chunk_infos["region"][0] : chunk_infos["region"][1] + 1
- ]:
+ for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
if is_non_compute_node_except_placeholder(node):
continue
count = 0
@@ -159,9 +152,7 @@ def _assgin_single_node_flow(
if arg_node in all_node_info:
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
return False
- all_node_info[arg_node]["fix_dim"] = list(
-     set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
- )
+ all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
# else add it to list
else:
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
@@ -170,9 +161,7 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx):
return True

def _get_all_node_info(self, end_dim, start_idx, end_idx):
- cur_node_list = [
-     self.trace_indice.node_list[end_idx]
- ] # start from the last node
+ cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}

while len(cur_node_list) > 0:
@@ -183,12 +172,8 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
if cur_node_chunk_dim:
- cur_node_compute = self.trace_indice._find_compute_trace_from_node(
-     cur_node
- )
- cur_node_source = self.trace_indice._find_source_trace_from_node(
-     cur_node
- )
+ cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
+ cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
else:
cur_node_compute = cur_node_source = None

@@ -215,15 +200,9 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx):
return None

if len(arg_list) == 2:
- if any(i in cur_node.name for i in ["add", "mul"]):
+ if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
for arg in arg_list:
- if not (
-     start_idx
-     <= find_idx_by_name(
-         arg.name, self.trace_indice.node_list
-     )
-     < end_idx
- ):
+ if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
continue
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
arg_fix_dim = all_node_info[arg]["fix_dim"]
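
Note on the hunk above: torch.fx derives node names from the traced op, so `x / y` produces a node named `truediv`. Adding it to the keyword list lets division nodes take the same two-argument elementwise path as `add` and `mul` when both args live inside the chunk region. A minimal sketch of the name-based check, as an illustration only (not code from this PR):

```python
# Illustration of the broadened keyword test. torch.fx names nodes after
# their op, e.g. `x / y` traces to a call_function node named "truediv".
ELEMENTWISE_PAIR_OPS = ["add", "mul", "truediv"]

def is_elementwise_pair(node_name: str) -> bool:
    return any(op in node_name for op in ELEMENTWISE_PAIR_OPS)

assert is_elementwise_pair("truediv_2")     # x / y now takes this path
assert not is_elementwise_pair("matmul_1")  # matmul is handled elsewhere
```
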
Expand All @@ -249,19 +228,15 @@ def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
remove_inputs = []
for input_node in inputs:
input_dict = {}
- input_node_idx = find_idx_by_name(
-     input_node.name, self.trace_indice.node_list
- )
+ input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
for user in input_node.users.keys():
if is_non_compute_node(user):
continue
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None:
- user_source = self.trace_indice._find_source_trace_from_node(
-     user
- )[chunk_dim]
+ user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
if input_node_idx in user_source:
input_dict[user_idx] = user_source[input_node_idx]
else:
@@ -284,7 +259,7 @@ def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
maybe_prepose_nodes.sort(
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
reverse=True,
- ) # from last node to first node
+ )  # from last node to first node
prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0:
@@ -305,13 +280,8 @@ def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue
# out of loop
- if not (
-     start_idx
-     <= find_idx_by_name(
-         cur_prepose_node_arg.name, self.trace_indice.node_list
-     )
-     < end_idx
- ):
+ if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
+         end_idx):
continue
# compute op in loop
elif cur_prepose_node_arg in all_node_info:
@@ -335,15 +305,13 @@ def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n)
# sort by index
- prepose_nodes.sort(
-     key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
- )
+ prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))

return prepose_nodes

def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleting them in the loop
- chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
+ chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
# also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
chunk_node_list.remove(n)
@@ -354,9 +322,7 @@ def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
return chunk_info

def flow_search(self, start_idx, start_dim, end_idx, end_dim):
- inputs, outputs = find_chunk_compute_input_and_output_nodes(
-     self.trace_indice.node_list[start_idx : end_idx + 1]
- )
+ inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
# only single output
if len(outputs) > 1:
return None
@@ -367,9 +333,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim):
return None

# get input nodes' chunk dim
- inputs, inputs_dim = self._get_input_nodes_dim(
-     inputs, start_idx, end_idx, all_node_info
- )
+ inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None

@@ -385,9 +349,7 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim):
}

# move useless nodes ahead of loop
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
all_node_info, start_idx, end_idx
)
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)

# find non chunk inputs
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
@@ -400,10 +362,8 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim):
def _reassgin_reshape_size(self, chunk_info):
chunk_region = chunk_info["region"]
reshape_size = {}
- chunk_shape = get_node_shape(chunk_info["outputs"][0])[
-     chunk_info["outputs_dim"]
- ]
- for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
+ chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
+ for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]):
reshape_args = node.args[1:]
reshape_log = self.trace_indice.indice_view_list[node]
@@ -413,8 +373,6 @@ def _reassgin_reshape_size(self, chunk_info):
if reshape_arg_dim in reshape_log["dim_to"]:
continue
if reshape_arg_dim == chunk_dim:
- reshape_size[node.name][reshape_arg.name] = (
-     "min(chunk_size, %d - chunk_idx)" % chunk_shape
- )
+ reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
chunk_info["reshape_size"] = reshape_size
return chunk_info
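
For orientation, the `chunk_info` dict that `flow_search` returns can be pieced together from the reads and writes in this file. The sketch below uses dummy values and is inferred from this diff alone, not authoritative:

```python
# Inferred layout of chunk_info, from the fields touched in this diff.
start_idx, end_idx, end_dim = 10, 42, 0    # hypothetical chunk region
chunk_info = {
    "region": (start_idx, end_idx),    # node-index span of the chunk loop
    "inputs": [],                      # chunkable input fx nodes
    "inputs_dim": [],                  # per-input chunk-dim records
    "outputs": [],                     # the single allowed output node
    "outputs_dim": end_dim,            # chunk dim of the output
    "args": {"prepose_nodes": []},     # nodes hoisted ahead of the loop
    "reshape_size": {},                # filled by _reassgin_reshape_size
}
```
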
62 changes: 33 additions & 29 deletions colossalai/autochunk/trace_indice.py
@@ -3,7 +3,7 @@

from torch.fx.node import Node

- from .utils import find_idx_by_name, get_node_shape
+ from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list


class TraceIndice(object):
@@ -79,9 +79,7 @@ def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim):
node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to)
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
- node_to_trace["compute"][node_to_dim] = copy.deepcopy(
-     node_from_trace["compute"][node_from_dim]
- )
+ node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)

def _inherit_all_computation(self, node_from, node_to):
@@ -209,7 +207,7 @@ def _assign_indice_as_input(self, node, node_idx, input_node=None):
node_idx (int)
"""
if input_node == None:
- input_node = node.args[0]
+ input_node = find_first_tensor_arg(node)
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]

@@ -227,6 +225,8 @@ def _assign_all_indice(self, node, node_idx):
node_idx (int)
"""
shape = node.meta["tensor_meta"].shape
+ if shape is None:
+     return
new_trace = []
for _ in shape:
new_trace.append(self._add_indice())
@@ -259,7 +259,7 @@ def _assign_permute_indice(self, node, node_idx):
node (node)
node_idx (int)
"""
- permute_dim = node.args[1:]
+ permute_dim = unflat_list(node.args[1:])
input_node = node.args[0]

self._assign_indice_as_input(node, node_idx, input_node)
@@ -359,6 +359,15 @@ def _assign_einsum_indice(self, node, idx):
left, right = patterns.split("->")
left = left.split(",")

+ if '...' in right:
+     replace_list = "!@#$%^&*"
+     target_len = len(get_node_shape(node))
+     add_len = target_len - len(right) + 3
+     replace_str = replace_list[:add_len]
+     right = right.replace("...", replace_str)
+     for ll in range(len(left)):
+         left[ll] = left[ll].replace("...", replace_str)
+
all_index = []
for i in left:
for c in i:
@@ -369,9 +378,7 @@
for left_idx, left_str in enumerate(left):
if right_indice in left_str:
source_idx = left_str.index(right_indice)
- self._inherit_indice(
-     input_nodes[left_idx], source_idx, node, right_idx
- )
+ self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx)

def _assign_softmax_indice(self, node, idx):
"""
@@ -440,11 +447,12 @@ def _assign_view_reshape_indice(self, node, node_idx):
origin_node = node.args[0]
origin_shape = origin_node.meta["tensor_meta"].shape
target_shape = []
- for i in range(1, len(node.args)):
-     if isinstance(node.args[i], int):
-         target_shape.append(node.args[i])
+ unflated_args = unflat_list(node.args)
+ for i in range(1, len(unflated_args)):
+     if isinstance(unflated_args[i], int):
+         target_shape.append(unflated_args[i])
else:
-     target_shape.append(node.args[i].meta["fwd_out"][0])
+     target_shape.append(unflated_args[i].meta["fwd_out"][0])

# compute the value of -1
if -1 in target_shape:
@@ -472,13 +480,7 @@ def _assign_view_reshape_indice(self, node, node_idx):
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
self._del_dim(node_idx, -1)
else:
- raise NotImplementedError(
-     "shape"
-     + str(origin_shape)
-     + "and"
-     + str(target_shape)
-     + "view not implemented"
- )
+ raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")

# get new indice
origin_trace = self._find_indice_trace_from_node(origin_node)
@@ -521,6 +523,8 @@ def trace_indice(self):
self._assign_unsqueeze_indice(node, idx)
elif any(i in node.name for i in ["to", "contiguous"]):
self._assgin_no_change_indice(node, idx)
elif "new_ones" in node.name:
self._assign_ones_like_indice(node, idx)
else:
raise NotImplementedError(node.name, "method not implemented yet!")
elif node.op == "call_function":
@@ -530,29 +534,29 @@ def trace_indice(self):
self._assign_matmul_indice(node, idx)
elif "softmax" in node.name:
self._assign_softmax_indice(node, idx)
- elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]):
+ elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]):
self._assign_elementwise_indice(node, idx)
elif "ones_like" in node.name:
self._assign_ones_like_indice(node, idx)
elif "dropout" in node.name:
self._assign_dropout_indice(node, idx)
elif "einsum" in node.name:
self._assign_einsum_indice(node, idx)
elif "getattr" in node.name:
continue # get attr like shape
elif "getitem" in node.name:
continue # get item in list
elif "layer_norm" in node.name:
self._assign_layernorm_indice(node, idx)
+ elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
+     continue
else:
- raise NotImplementedError(
-     node.name, "function not implemented yet!"
- )
+ raise NotImplementedError(node.name, "function not implemented yet!")
elif node.op == "call_module":
if any(n in node.name for n in ["layernorm", "norm"]):
self._assign_layernorm_indice(node, idx)
elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]):
self._assign_elementwise_indice(node, idx)
else:
raise NotImplementedError(node.name, "module not implemented yet!")
elif node.op == "get_attr":
- self._assign_all_indice(node, idx) # get param
+ self._assign_all_indice(node, idx)  # get param
elif node.op == "output":
continue
else:
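
The `utils.py` diff is not shown in this view, so the two new helpers imported at the top of `trace_indice.py` are sketched below from their call sites only; treat these as hypothetical reconstructions, not the PR's actual implementations:

```python
from torch.fx.node import Node

def unflat_list(args):
    """Hypothetical: flatten one level of nesting so that
    view(x, (a, b)) and view(x, a, b) yield the same arg list."""
    flat = []
    for a in args:
        if isinstance(a, (list, tuple)):
            flat.extend(a)
        else:
            flat.append(a)
    return flat

def find_first_tensor_arg(node: Node) -> Node:
    """Hypothetical: return the first arg that is an fx Node (a tensor),
    skipping non-tensor args such as scalars."""
    for arg in node.args:
        if isinstance(arg, Node):
            return arg
    raise RuntimeError(f"no tensor arg found for {node.name}")
```
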