61 changes: 35 additions & 26 deletions colossalai/auto_parallel/tensor_shard/initialize.py
@@ -59,18 +59,6 @@ def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader,
pass


def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
'''
This method is used to search the best logical mesh shape for the given world size
based on the alpha_beta_dict.

For example:
if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1).
'''
# TODO: implement this function
return (world_size, 1)


def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
'''
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
@@ -127,39 +115,56 @@ def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh


def initialize_device_mesh(world_size: int = -1,
physical_devices: List[int] = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None):
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None):
'''
This method is used to initialize the device mesh.

Args:
world_size(optional): the size of device mesh. If the world_size is -1,
world_size: the size of device mesh. If the world_size is -1,
the world size will be set to the number of GPUs in the current machine.
physical_devices: the physical devices used to initialize the device mesh.
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
generated by profile_alpha_beta function.
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
mesh shape.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
'''
# if world_size is not set, use the world size from torch.distributed
if world_size == -1:
world_size = dist.get_world_size()
device1d = [i for i in range(world_size)]

if physical_devices is None:
physical_devices = [i for i in range(world_size)]
physical_mesh = torch.tensor(physical_devices)

if alpha_beta_dict is None:
# if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
alpha_beta_dict = profile_alpha_beta(device1d)
ab_profiler = AlphaBetaProfiler(physical_devices)
alpha_beta_dict = ab_profiler.alpha_beta_dict
else:
ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)

if logical_mesh_shape is None:
if logical_mesh_shape is None and logical_mesh_id is None:
# search for the best logical mesh shape
logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
logical_mesh_id = ab_profiler.search_best_logical_mesh()
logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
logical_mesh_shape = logical_mesh_id.shape

# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()

elif logical_mesh_shape is not None and logical_mesh_id is None:
logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)

# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)

# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
physical_mesh = torch.tensor(device1d)
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
mesh_shape=logical_mesh_shape,
logical_mesh_id=logical_mesh_id,
mesh_alpha=mesh_alpha,
mesh_beta=mesh_beta,
init_process_group=True)
@@ -224,6 +229,7 @@ def autoparallelize(model: nn.Module,
data_process_func: callable = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
@@ -245,6 +251,7 @@
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -254,7 +261,9 @@
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity.
'''
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
logical_mesh_shape=logical_mesh_shape,
logical_mesh_id=logical_mesh_id)
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)

@@ -263,7 +272,7 @@
device_mesh,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solver_solution_path=solver_solution_path,
solution_path=solver_solution_path,
return_solution=return_solution,
memory_budget=memory_budget)

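The net effect of the hunks above: device-mesh setup now goes through AlphaBetaProfiler instead of the removed search_best_logical_mesh_shape stub, and callers can supply physical_devices, logical_mesh_shape, or logical_mesh_id explicitly. A minimal usage sketch of the new signature, assuming a distributed run is already launched (e.g. torchrun on 4 GPUs); the argument names follow the diff, everything else is illustrative:

import torch
from colossalai.auto_parallel.tensor_shard.initialize import initialize_device_mesh

# Case 1: let AlphaBetaProfiler measure alpha/beta and search the best logical mesh.
device_mesh = initialize_device_mesh(world_size=4)

# Case 2: fix the logical mesh shape yourself; the physical mesh [0, 1, 2, 3]
# is reshaped to (2, 2) and alpha/beta are extracted for that mesh.
device_mesh = initialize_device_mesh(world_size=4, logical_mesh_shape=(2, 2))

A prebuilt logical_mesh_id tensor can be passed instead of logical_mesh_shape, matching the new keyword added to both initialize_device_mesh and autoparallelize.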
4 changes: 3 additions & 1 deletion colossalai/device/alpha_beta_profiler.py
@@ -381,6 +381,8 @@ def _extract_alpha_beta(pg, pg_handler):
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
mesh_alpha = [first_latency, second_latency]
mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
# The beta values have been enlarged by 1e10 times temporarily because the computation cost
# is still estimated in the unit of TFLOPs instead of time. We will remove this factor in the future.
mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]

return mesh_alpha, mesh_beta
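For context, mesh_alpha and mesh_beta feed the usual alpha-beta communication cost model (fixed latency plus per-byte transfer time); the 1e10 factor only rescales beta so that communication cost stays comparable to a compute cost still counted in TFLOPs. A small illustration with made-up numbers, not code from the repository:

def comm_cost(num_bytes: float, alpha: float, beta: float) -> float:
    # alpha-beta model: fixed latency plus per-byte transfer time
    return alpha + beta * num_bytes

alpha = 1e-5          # 10 us latency (made up)
beta = 1.0 / 100e9    # 100 GB/s bandwidth (made up)
print(comm_cost(64 * 1024 * 1024, alpha, beta))  # ~0.68 ms for a 64 MB message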
30 changes: 19 additions & 11 deletions colossalai/device/device_mesh.py
@@ -1,5 +1,6 @@
import operator
from functools import reduce
from typing import List, Tuple

import torch
import torch.distributed as dist
@@ -15,7 +16,8 @@ class DeviceMesh:

Arguments:
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
mesh_shape (torch.Size): shape of logical view.
logical_mesh_id (torch.Tensor): logical view of the devices in global rank.
mesh_shape (torch.Size, optional): shape of logical view.
mesh_alpha (List[float], optional): coefficients used for computing
communication cost (default: None)
mesh_beta (List[float], optional): coefficients used for computing
@@ -28,15 +30,21 @@
"""

def __init__(self,
physical_mesh_id,
mesh_shape,
mesh_alpha=None,
mesh_beta=None,
init_process_group=False,
need_flatten=True):
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
logical_mesh_id: torch.Tensor = None,
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
need_flatten: bool = True):
self.physical_mesh_id = physical_mesh_id
self.mesh_shape = mesh_shape
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
if logical_mesh_id is None:
self.mesh_shape = mesh_shape
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
else:
self._logical_mesh_id = logical_mesh_id
self.mesh_shape = self._logical_mesh_id.shape

# map global rank into logical rank
self.convert_map = {}
self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
@@ -54,8 +62,8 @@ def __init__(self,
if self.need_flatten and self._logical_mesh_id.dim() > 1:
self.flatten_device_mesh = self.flatten()
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
self.mesh_beta)
# self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
# self.mesh_beta)

@property
def shape(self):
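With this change DeviceMesh can be built either from a mesh_shape (the physical mesh is reshaped) or from a prebuilt logical_mesh_id tensor (the shape is inferred from it). A minimal sketch of the two paths, assuming 4 devices and no process-group initialization; only names visible in the diff are used:

import torch
from colossalai.device.device_mesh import DeviceMesh

physical_mesh_id = torch.arange(4)

# Path 1: pass mesh_shape; the logical mesh is physical_mesh_id.reshape(mesh_shape).
mesh_a = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))

# Path 2: pass logical_mesh_id directly; mesh_shape is taken from its .shape.
mesh_b = DeviceMesh(physical_mesh_id, logical_mesh_id=torch.arange(4).reshape(2, 2))

print(mesh_a.mesh_shape, mesh_b.mesh_shape)  # (2, 2) and torch.Size([2, 2])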
@@ -16,14 +16,14 @@
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger

BATCH_SIZE = 8
SEQ_LENGTH = 128
HIDDEN_DIM = 3072
BATCH_SIZE = 16
SEQ_LENGTH = 1024
HIDDEN_DIM = 4096
NUM_HEADS = 16
NUM_LAYERS = 1
NUM_LAYERS = 4
VOCAB_SIZE = 50257
NUM_STEPS = 10
FP16 = False
FP16 = True


def get_cpu_mem():
@@ -40,7 +40,7 @@ def get_mem_info(prefix=''):

def get_tflops(model_numel, batch_size, seq_len, step_time):
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 8


# Randomly Generated Data
@@ -66,13 +66,7 @@ def main():
'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
}

# Both device mesh initialization and model initialization will be integrated into autoparallelize
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

# Enable auto-parallel
gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True)
gm, solution = autoparallelize(model, meta_input_sample, return_solution=True)

# print solution on rank 0
if gpc.get_global_rank() == 0:
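The benchmark itself now skips manual DeviceMesh construction and calls autoparallelize directly, with a larger workload (batch 16, sequence 1024, 4 layers, FP16) and the throughput divisor bumped from 4 to 8 GPUs. A quick sanity check of the updated get_tflops formula, using a hypothetical 1.3B-parameter model and a made-up step time:

def get_tflops(model_numel, batch_size, seq_len, step_time, num_gpus=8):
    # per-GPU throughput: 8 * params * processed tokens / step time / #GPUs
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / num_gpus

# hypothetical numbers: 1.3e9 parameters, 0.5 s per step
print(get_tflops(1.3e9, 16, 1024, step_time=0.5))  # ~42.6 TFLOPs per GPU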