14 changes: 8 additions & 6 deletions colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -387,14 +387,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # register hook to the parameters
             if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:

-                def wrapper(param, comm_spec):
+                def wrapper(param, comm_spec, stream):

                     def hook_fn(grad):
-                        _all_reduce(grad, comm_spec, async_op=False)
+                        with torch.cuda.stream(stream):
+                            _all_reduce(grad, comm_spec, async_op=True)

                     param.register_hook(hook_fn)

-                wrapper(param, comm_spec_to_use)
+                wrapper(param, comm_spec_to_use, reduction_stream)

     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
@@ -440,14 +440,15 @@ def hook_fn(grad):
             # register hook to the parameters
             if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:

-                def wrapper(param, comm_spec):
+                def wrapper(param, comm_spec, stream):

                     def hook_fn(grad):
-                        _all_reduce(grad, comm_spec, async_op=False)
+                        with torch.cuda.stream(stream):
+                            _all_reduce(grad, comm_spec, async_op=True)

                     param.register_hook(hook_fn)

-                wrapper(target, comm_spec_to_use)
+                wrapper(target, comm_spec_to_use, reduction_stream)
     return gm

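For readers unfamiliar with the pattern in the hunks above: the gradient all-reduce is moved off the default stream onto a dedicated CUDA stream and issued asynchronously, so communication can overlap with the remaining backward computation. The standalone sketch below only illustrates the idea; the helper name, the process-group setup, and the explicit stream synchronization before the optimizer step are assumptions for illustration, not code from this PR.

import torch
import torch.distributed as dist

# Assumes torch.distributed is already initialized (e.g. via torchrun) and the
# model parameters live on the current CUDA device.
reduction_stream = torch.cuda.Stream()

def register_grad_allreduce(param: torch.nn.Parameter, stream: torch.cuda.Stream) -> None:
    """Register a backward hook that all-reduces the gradient on a side stream."""

    def hook_fn(grad: torch.Tensor) -> None:
        # The side stream must wait until the default stream has produced `grad`.
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True)

    param.register_hook(hook_fn)

# Usage sketch:
#   for p in model.parameters():
#       register_grad_allreduce(p, reduction_stream)
#   loss.backward()
#   torch.cuda.current_stream().wait_stream(reduction_stream)  # before optimizer.step()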

@@ -483,4 +483,6 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
                 raise TypeError(
                     f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
             strategies = recovered_stragies
+        for index, strategies in enumerate(strategies):
+            strategies.name = f"{strategies.name}_{index}"
         return strategies
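
The two added lines suffix every recovered strategy's name with its index so that strategies produced by the recover step no longer collide by name. A minimal standalone illustration of that renaming, using a hypothetical Strategy stand-in for ShardingStrategy and a separate loop variable for readability:

from dataclasses import dataclass
from typing import List

@dataclass
class Strategy:  # hypothetical stand-in for ShardingStrategy
    name: str

def suffix_with_index(strategies: List[Strategy]) -> List[Strategy]:
    # Append each strategy's position so duplicated names become distinguishable.
    for index, strategy in enumerate(strategies):
        strategy.name = f"{strategy.name}_{index}"
    return strategies

print([s.name for s in suffix_with_index([Strategy("bmm"), Strategy("bmm")])])
# -> ['bmm_0', 'bmm_1']
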
@@ -247,12 +247,12 @@ def collate_strategies(self) -> List[ShardingStrategy]:
         strategies.append(self.split_rhs_space_both_contract(1, 0))

         # RR= RS x SR
-        strategies.append(self.recompute_split_both_contract(0))
-        strategies.append(self.recompute_split_both_contract(1))
+        # strategies.append(self.recompute_split_both_contract(0))
+        # strategies.append(self.recompute_split_both_contract(1))

-        # RS = RR x RS
-        strategies.append(self.split_rhs_space_only(0))
-        strategies.append(self.split_rhs_space_only(1))
+        # # RS = RR x RS
+        # strategies.append(self.split_rhs_space_only(0))
+        # strategies.append(self.split_rhs_space_only(1))

         # S01R = S01R x RR
         strategies.append(self.split_lhs_1st_dim_1d(0, 1))
@@ -263,8 +263,8 @@ def collate_strategies(self) -> List[ShardingStrategy]:
         # RS01 = RR x RS01
         strategies.append(self.split_rhs_2nd_dim_1d(0, 1))

-        # RR = RR x RR
-        strategies.append(self.non_split())
+        # # RR = RR x RR
+        # strategies.append(self.non_split())

         return strategies

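As background for the R/S comments in the hunks above: each tensor dimension is either replicated (R) or sharded along a mesh axis (S0, S1, or S01 for both). "RR = RS x SR", for example, shards the contracted dimension of a matmul, so every device computes a partial product that must be summed across devices (an all-reduce) to recover the replicated result. The single-process sketch below demonstrates that arithmetic; it is illustrative only and not ColossalAI code.

import torch

A = torch.randn(4, 6)   # left operand, contracted dim of size 6
B = torch.randn(6, 5)   # right operand
num_shards = 2          # stand-in for the number of devices along one mesh axis
chunk = 6 // num_shards

# Each "device" holds a column slice of A and the matching row slice of B (RS x SR).
partials = [A[:, k * chunk:(k + 1) * chunk] @ B[k * chunk:(k + 1) * chunk, :]
            for k in range(num_shards)]

# Summing the partial products (the all-reduce) recovers the replicated RR output.
assert torch.allclose(sum(partials), A @ B, atol=1e-5)
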
2 changes: 1 addition & 1 deletion colossalai/device/device_mesh.py
@@ -98,7 +98,7 @@ def flatten(self):
         return DeviceMesh(self.physical_mesh_id,
                           tuple(flatten_mesh_shape),
                           mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
-                          mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
                           init_process_group=self.init_process_group,
                           need_flatten=False)

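For context on this one-line change: mesh_alpha and mesh_beta are the latency and per-byte terms of an alpha-beta communication cost model, one value per mesh axis. When a multi-axis mesh is flattened into a single logical axis, taking max(self.mesh_beta) keeps the worst-case bandwidth term, mirroring how max(self.mesh_alpha) is used on the line above. The snippet below is a generic alpha-beta estimate for illustration only, not DeviceMesh's actual cost formulas.

# Illustrative alpha-beta estimate: a collective over n devices costs roughly
# a fixed latency plus a bandwidth term proportional to the message size.
def alpha_beta_cost(num_bytes: float, alpha: float, beta: float, num_devices: int) -> float:
    return alpha + beta * num_bytes * (num_devices - 1) / num_devices

# Flattening a 2-D mesh keeps the worst-case term from each axis.
mesh_alpha = [1e-5, 2e-5]     # hypothetical per-axis latency (s)
mesh_beta = [1e-10, 5e-10]    # hypothetical per-axis cost per byte (s/B)
flat_alpha, flat_beta = max(mesh_alpha), max(mesh_beta)
print(alpha_beta_cost(num_bytes=4 * 1024 ** 2, alpha=flat_alpha, beta=flat_beta, num_devices=8))
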
6 changes: 3 additions & 3 deletions colossalai/tensor/comm_spec.py
@@ -463,7 +463,7 @@ def get_comm_cost(self):
         if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
             forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
             # give a tiny cost to shard
-            backward_communication_cost = 10
+            backward_communication_cost = 100

         if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
             forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
@@ -481,13 +481,13 @@ def get_comm_cost(self):

         if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
             # give a tiny cost to shard
-            forward_communication_cost = 10
+            forward_communication_cost = 100
             backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)

         if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
             # no need for axis because all devices are used in mix_gather
             forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
-            backward_communication_cost = 10
+            backward_communication_cost = 100

         if self.forward_only:
             cost_dict["forward"] = forward_communication_cost
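
The constants changed here are the nominal cost charged for the "shard" direction of a communication pattern: splitting a tensor is purely local, so it receives a small flat value instead of a mesh-dependent estimate, and raising it from 10 to 100 weights that placeholder more heavily when strategies are compared. Below is a hypothetical helper (not comm_spec.get_comm_cost itself) showing how such a flat shard cost sits next to a real collective estimate.

SHARD_COST = 100  # flat placeholder for a local split; no network traffic involved

def gather_fwd_split_bwd_cost(all_gather_estimate: float) -> dict:
    # Forward pass performs a real all-gather; backward pass only splits locally.
    return {
        "forward": all_gather_estimate,
        "backward": SHARD_COST,
        "total": all_gather_estimate + SHARD_COST,
    }

print(gather_fwd_split_bwd_cost(all_gather_estimate=3.1e-3))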