Commit 175fd26

ptensor
1 parent 14a4342 commit 175fd26

9 files changed: +326 -133 lines changed


colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 76 additions & 56 deletions
@@ -4,7 +4,7 @@
 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Set, Tuple
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple

 import torch
 import torch.distributed as dist
@@ -14,6 +14,7 @@

 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
 from colossalai.utils import get_current_device

 from .general_checkpoint_io import GeneralCheckpointIO
@@ -77,39 +78,39 @@ def __init__(
         self.verbose = verbose
         self.coordinator = DistCoordinator()

-    @staticmethod
-    def _named_modules(
-        module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
-    ):
-        r"""Returns an iterator over all leaf modules in the network, yielding
-        both the name of the module as well as the module itself.
-
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-
-        Yields:
-            (str, Module): Tuple of name and module
-
-        Note:
-            Duplicate modules are returned only once. In the following
-            example, ``l`` will be returned only once.
-        """
-        if memo is None:
-            memo = set()
-
-        if module not in memo:
-            sub_modules = [(name, subm) for (name, subm) in module._modules.items() if subm is not None]
-            if len(sub_modules) == 0:
-                if remove_duplicate:
-                    memo.add(module)
-                yield prefix, module
-            else:
-                for name, subm in sub_modules:
-                    submodule_prefix = prefix + ("." if prefix else "") + name
-                    yield from HybridParallelCheckpointIO._named_modules(subm, memo, submodule_prefix, remove_duplicate)
+    # @staticmethod
+    # def _named_modules(
+    #     module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+    # ):
+    #     r"""Returns an iterator over all leaf modules in the network, yielding
+    #     both the name of the module as well as the module itself.
+
+    #     Args:
+    #         memo: a memo to store the set of modules already added to the result
+    #         prefix: a prefix that will be added to the name of the module
+    #         remove_duplicate: whether to remove the duplicated module instances in the result
+    #             or not
+
+    #     Yields:
+    #         (str, Module): Tuple of name and module
+
+    #     Note:
+    #         Duplicate modules are returned only once. In the following
+    #         example, ``l`` will be returned only once.
+    #     """
+    #     if memo is None:
+    #         memo = set()
+
+    #     if module not in memo:
+    #         sub_modules = [(name, subm) for (name, subm) in module._modules.items() if subm is not None]
+    #         if len(sub_modules) == 0:
+    #             if remove_duplicate:
+    #                 memo.add(module)
+    #             yield prefix, module
+    #         else:
+    #             for name, subm in sub_modules:
+    #                 submodule_prefix = prefix + ("." if prefix else "") + name
+    #                 yield from HybridParallelCheckpointIO._named_modules(subm, memo, submodule_prefix, remove_duplicate)

     @staticmethod
     def _model_sharder(
@@ -120,18 +121,29 @@ def _model_sharder(
         state_dict_sharder = StateDictSharder(size_per_shard)

         # Save parameters.
-        for module_name, module in HybridParallelCheckpointIO._named_modules(model):
-            state_dicts = module.state_dict()
-            for name, param in state_dicts.items():
-                if param is None:
-                    continue
-                # Gather tensor pieces when using tensor parallel.
-                param_ = gather_distributed_param(param, keep_vars=False)
-                if module_name != "":
-                    module_name = module_name + "."
-                block, block_size = state_dict_sharder.append_param(module_name + name, param_)
-                if block is not None:
-                    yield block, block_size
+        # for module_name, module in HybridParallelCheckpointIO._named_modules(model):
+        #     state_dicts = module.state_dict()
+        #     for name, param in state_dicts.items():
+        #         if param is None:
+        #             continue
+        #         # Gather tensor pieces when using tensor parallel.
+        #         param_ = gather_distributed_param(param, keep_vars=False)
+        #         if module_name != "":
+        #             module_name = module_name + "."
+        #         block, block_size = state_dict_sharder.append_param(module_name + name, param_)
+        #         if block is not None:
+        #             yield block, block_size
+        for name, param in model.named_parameters():
+            if param is None:
+                continue
+            # Gather tensor pieces when using tensor parallel.
+            if is_padded_tensor(param):
+                print("bbbbbbbbbbbbbbbbbbbbbbbbbb")
+                param = to_unpadded_tensor(param)
+            param_ = gather_distributed_param(param, keep_vars=False)
+            block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+            if block is not None:
+                yield block, block_size

         # Save buffers.
         for name, buf in model.named_buffers():
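
To make the new save path concrete: `_model_sharder` now walks `model.named_parameters()` directly and strips padding before a parameter is gathered and sharded. Below is a minimal sketch of that idea using only the `p_tensor` helpers this commit adds; the toy `TinyModel` and the plain dict standing in for `gather_distributed_param`/`StateDictSharder` are illustrative and not part of the codebase.

import torch
import torch.nn as nn

from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor


class TinyModel(nn.Module):
    # Toy stand-in for a sharded model: an embedding whose vocab gets padded from 11 to 16 rows.
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(11, 4)


model = TinyModel()
# Pad the embedding weight along dim 0, mirroring what resize_embedding_weight() now does.
to_padded_tensor(model.embed.weight, 16, 0)  # mutates the parameter in place and tags it

shard = {}  # stands in for StateDictSharder.append_param(...)
for name, param in model.named_parameters():
    if param is None:
        continue
    if is_padded_tensor(param):
        # Strip the padding rows before the tensor would be gathered and appended to a shard.
        param = to_unpadded_tensor(param)
    shard[name] = param.detach().clone()

print(shard["embed.weight"].shape)  # torch.Size([11, 4]) -- checkpoints keep the unpadded shape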
@@ -906,7 +918,13 @@ def gather_from_sharded_optimizer_state(
                     dist.all_gather(gather_tensor, v, group=tp_group)
                     v = torch.cat(gather_tensor, dim=partition_dim)

-                state_[k] = v.detach().clone()[: original_shape[0], ...].to(device)
+                padding_dim = search_padding_dim(v.shape, original_shape)
+                if padding_dim is not None:
+                    print("cccccccccccec")
+                    v = init_as_ptensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
+                    v = to_unpadded_tensor(v)
+
+                state_[k] = v.detach().clone().to(device)

         return state_

@@ -949,15 +967,17 @@ def shard_from_complete_optimizer_state(

                 padding_dim = search_padding_dim(global_shape, original_shape)
                 if padding_dim is not None:
-                    padding_size = global_shape[padding_dim] - original_shape[padding_dim]
-                    padding_data = torch.zeros(
-                        *v.shape[:padding_dim],
-                        padding_size,
-                        *v.shape[padding_dim + 1 :],
-                        device=v.device,
-                        dtype=v.dtype,
-                    )
-                    v = torch.cat((v, padding_data), dim=padding_dim).contiguous()
+                    # padding_size = global_shape[padding_dim] - original_shape[padding_dim]
+                    # padding_data = torch.zeros(
+                    #     *v.shape[:padding_dim],
+                    #     padding_size,
+                    #     *v.shape[padding_dim + 1 :],
+                    #     device=v.device,
+                    #     dtype=v.dtype,
+                    # )
+                    # v = torch.cat((v, padding_data), dim=padding_dim).contiguous()
+                    print("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
+                    v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)

                 if partition_dim is not None:
                     slice_size = current_shape[partition_dim]
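
`search_padding_dim` is defined elsewhere in the checkpoint utilities and is not part of this diff; judging from how it is called here, it returns the dimension along which the gathered shape exceeds the original one (or `None`). A hedged sketch of that behaviour, using a hypothetical stand-in named `search_padding_dim_sketch`, shows why the old hard-coded `[: original_shape[0], ...]` slice could be replaced by `init_as_ptensor` plus `to_unpadded_tensor`.

from typing import Optional

import torch

from colossalai.tensor.p_tensor import init_as_ptensor, to_unpadded_tensor


def search_padding_dim_sketch(global_shape, original_shape) -> Optional[int]:
    # Hypothetical stand-in for the real search_padding_dim (not shown in this diff):
    # return the first dim where the gathered shape is larger than the original shape.
    for dim, (g, o) in enumerate(zip(global_shape, original_shape)):
        if g > o:
            return dim
    return None


# A gathered optimizer state padded to 16 rows; the original parameter had 11.
v = torch.randn(16, 4)
original_shape = torch.Size([11, 4])

padding_dim = search_padding_dim_sketch(v.shape, original_shape)
if padding_dim is not None:
    # Tag the gathered tensor with its padding metadata, then strip the padding.
    v = init_as_ptensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
    v = to_unpadded_tensor(v)

print(v.shape)  # torch.Size([11, 4]); unlike the old slice, this handles padding on any dim, not just dim 0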

colossalai/shardformer/layer/parallel_module.py

Lines changed: 31 additions & 21 deletions
@@ -20,6 +20,7 @@
     is_distributed_tensor,
     sharded_tensor_to_param,
 )
+from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

 __all__ = ["ParallelModule"]

@@ -230,10 +231,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         for name, param in self._parameters.items():
             if param is not None:
                 param = gather_distributed_param(param, keep_vars=keep_vars)
-                if self.new_num_embeddings > self.old_num_embeddings:
-                    destination[prefix + name] = param[: self.old_num_embeddings, ...].data
-                else:
-                    destination[prefix + name] = param.data
+                # if self.new_num_embeddings > self.old_num_embeddings:
+                #     destination[prefix + name] = param[: self.old_num_embeddings, ...].data
+                # else:
+                #     destination[prefix + name] = param.data
+                if is_padded_tensor(param):
+                    param = to_unpadded_tensor(param)
+                destination[prefix + name] = param.data

         for name, buf in self._buffers.items():
             if buf is not None and name not in self._non_persistent_buffers_set:
@@ -296,12 +300,15 @@ def _load_from_state_dict(
                     )
                     continue

-                if self.new_num_embeddings > self.old_num_embeddings:
-                    num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
-                    padding_embeddings = torch.zeros(
-                        num_padding_tokens, *input_param.shape[1:], device=input_param.device, dtype=input_param.dtype
-                    )
-                    input_param.data = torch.cat((input_param.data, padding_embeddings), dim=0).contiguous()
+                # if self.new_num_embeddings > self.old_num_embeddings:
+                #     num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
+                #     padding_embeddings = torch.zeros(
+                #         num_padding_tokens, *input_param.shape[1:], device=input_param.device, dtype=input_param.dtype
+                #     )
+                #     input_param.data = torch.cat((input_param.data, padding_embeddings), dim=0).contiguous()
+                if is_padded_tensor(param):
+                    print("is_padded_tensor(param)", is_padded_tensor(param))
+                    input_param = to_padded_tensor(input_param, param.current_length, param.padding_dim)

                 if is_distributed_tensor(param):
                     # shard the input param
@@ -359,16 +366,19 @@
                 unexpected_keys.append(key)

     def resize_embedding_weight(self):
-        num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
-        valid_weight = self.weight.data
-        padding_weight = torch.zeros(
-            num_padding_tokens, *self.weight.shape[1:], device=self.weight.device, dtype=self.weight.dtype
-        )
-        # padding to embedding
-        self.weight.data = torch.cat((valid_weight, padding_weight), dim=0).contiguous()
+        self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0)
+        # num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
+        # valid_weight = self.weight.data
+        # padding_weight = torch.zeros(
+        #     num_padding_tokens, *self.weight.shape[1:], device=self.weight.device, dtype=self.weight.dtype
+        # )
+        # # padding to embedding
+        # self.weight.data = torch.cat((valid_weight, padding_weight), dim=0).contiguous()

     def resize_embedding_bias(self):
-        num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
-        valid_bias = self.bias.data
-        padding_bias = torch.zeros((num_padding_tokens), device=self.bias.device, dtype=self.bias.dtype)
-        self.bias.data = torch.cat((valid_bias, padding_bias), dim=0).contiguous()
+        print("resize bias")
+        self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0)
+        # num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
+        # valid_bias = self.bias.data
+        # padding_bias = torch.zeros((num_padding_tokens), device=self.bias.device, dtype=self.bias.dtype)
+        # self.bias.data = torch.cat((valid_bias, padding_bias), dim=0).contiguous()
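
The two sides are symmetric: `_save_to_state_dict` writes the unpadded weight, and `_load_from_state_dict` re-pads whatever comes out of the checkpoint to match the padded parameter, using the metadata (`current_length`, `padding_dim`) that `to_padded_tensor` attached when the embedding was resized. A minimal sketch of the load side, with toy tensors standing in for the embedding weight and the checkpoint entry:

import torch

from colossalai.tensor.p_tensor import to_padded_tensor

# Live parameter: vocab padded from 11 to 16 rows, as resize_embedding_weight() does.
param = torch.nn.Parameter(torch.randn(11, 4))
to_padded_tensor(param, 16, 0)  # pads in place and attaches padding_dim / current_length

# Checkpoint entry: saved unpadded, so it still has the original 11 rows.
input_param = torch.randn(11, 4)

# Mirror of the new _load_from_state_dict branch: re-pad the incoming tensor so its shape
# matches the padded parameter before any sharding or copy happens.
input_param = to_padded_tensor(input_param, param.current_length, param.padding_dim)

print(input_param.shape, param.shape)  # torch.Size([16, 4]) torch.Size([16, 4])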
colossalai/tensor/p_tensor/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from .api import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+
+__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_ptensor"]

colossalai/tensor/p_tensor/api.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+import torch
+
+
+def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
+    """
+    Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be hijacked.
+
+    Returns:
+        torch.Tensor: The hijacked tensor.
+    """
+    ptensor._unpad_detach = ptensor.detach
+    ptensor._unpad_clone = ptensor.clone
+
+    def new_detach(self):
+        t_ = self._unpad_detach()
+        t_.padding_dim = self.padding_dim
+        t_.origin_length = self.origin_length
+        t_.current_length = self.current_length
+        return t_
+
+    def new_clone(self, *args, **kwargs):
+        t_ = self._unpad_clone(*args, **kwargs)
+        t_.padding_dim = self.padding_dim
+        t_.origin_length = self.origin_length
+        t_.current_length = self.current_length
+        return t_
+
+    # bind the new methods to the tensor
+    ptensor.detach = new_detach.__get__(ptensor)
+    ptensor.clone = new_clone.__get__(ptensor)
+    return ptensor
+
+
+def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
+    """
+    Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be hijacked.
+
+    Returns:
+        torch.Tensor: The hijacked tensor.
+    """
+    ptensor.detach = ptensor._unpad_detach
+    ptensor.clone = ptensor._unpad_clone
+
+    delattr(ptensor, "_unpad_detach")
+    delattr(ptensor, "_unpad_clone")
+
+    return ptensor
+
+
+def is_padded_tensor(tensor: torch.Tensor) -> bool:
+    """
+    Check whether the given tensor is a padding tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be checked.
+
+    Returns:
+        bool: Whether the given tensor is a padding tensor.
+    """
+    return hasattr(tensor, "padding_dim")
+
+
+def to_padded_tensor(
+    tensor: torch.Tensor,
+    current_length: int,
+    padding_dim: int,
+) -> torch.Tensor:
+    assert (
+        padding_dim < tensor.dim()
+    ), f"Please passing a valid padding_dim. the dimension of the tensor is {tensor.dim()}"
+
+    if is_padded_tensor(tensor):
+        return tensor
+
+    origin_length = tensor.shape[padding_dim]
+    padding_num = current_length - origin_length
+    padding_data = torch.zeros(
+        *tensor.shape[:padding_dim],
+        padding_num,
+        *tensor.shape[padding_dim + 1 :],
+        device=tensor.device,
+        dtype=tensor.dtype,
+    )
+    tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()
+
+    setattr(tensor, "padding_dim", padding_dim)
+    setattr(tensor, "origin_length", origin_length)
+    setattr(tensor, "current_length", current_length)
+
+    _hijack_detach_and_clone(tensor)
+
+    return tensor
+
+
+def to_unpadded_tensor(ptensor: torch.Tensor):
+    print("ptensor", ptensor.shape)
+    if not is_padded_tensor(ptensor):
+        return ptensor
+
+    unpad_slices = [slice(None)] * ptensor.dim()
+    unpad_slices[ptensor.padding_dim] = slice(None, ptensor.origin_length)
+    tensor = ptensor[tuple(unpad_slices)]
+
+    delattr(ptensor, "padding_dim")
+    delattr(ptensor, "origin_length")
+    delattr(ptensor, "current_length")
+
+    _hijack_back_detach_and_clone(ptensor)
+
+    return tensor
+
+
+def init_as_ptensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
+    if is_padded_tensor(tensor):
+        return tensor
+
+    setattr(tensor, "padding_dim", padding_dim)
+    setattr(tensor, "origin_length", origin_length)
+    setattr(tensor, "current_length", current_length)
+
+    _hijack_detach_and_clone(tensor)
+
+    return tensor
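
A quick round trip of the API above (a usage sketch, not part of the commit), assuming the `colossalai.tensor.p_tensor` package added here is importable:

import torch

from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

t = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Pad dim 0 from 3 to 5 rows; the metadata is attached to the same tensor object.
t = to_padded_tensor(t, 5, 0)
print(t.shape, t.padding_dim, t.origin_length, t.current_length)  # torch.Size([5, 4]) 0 3 5

# detach()/clone() are hijacked so the padding metadata survives copies.
c = t.clone()
print(is_padded_tensor(c))  # True

# Unpadding slices back to the original 3 rows and removes the metadata from `t`
# (the committed to_unpadded_tensor also emits its own debug print here).
u = to_unpadded_tensor(t)
print(u.shape, is_padded_tensor(t), is_padded_tensor(u))  # torch.Size([3, 4]) False False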
