
Commit fb85479

support mode 2 sp in gpt2 (hpcaitech#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* refactor ring implementation
* support mode 2 sp in gpt2
1 parent 2e82701 commit fb85479

File tree

6 files changed: +198 additions, -80 deletions (only colossalai/shardformer/layer/_operation.py is shown below)

colossalai/shardformer/layer/_operation.py

Lines changed: 171 additions & 71 deletions
@@ -167,47 +167,57 @@ def backward(ctx, grad_output):
         return grad_input, grad_weight, grad_bias, None, None, None


-def _AllgatherLinear(input_, weight, process_group):
+def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
+    # currently only support one single tensor as output
     group_size = dist.get_world_size(process_group)
     cur_rank = dist.get_rank(process_group)

-    input_shape = input_.shape
-    weight_shape = weight.shape
-
-    output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
+    #output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]

     # initialization of ring communication
-    input_shape[1]
     recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
     send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
-    recv_tensor = input_.clone()
-    send_tensor = input_.clone()
-
-    recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-    send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-    handles = dist.batch_isend_irecv([send_op, recv_op])
+    recv_tensors = {}
+    send_tensors = {}
+    for k, v in input_to_gather.items():
+        recv_tensors[k] = v.clone()
+        send_tensors[k] = v.clone()
+
+    def communicate_step():
+        comm_ops = []
+        for k in recv_tensors:
+            comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
+            comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
+        return dist.batch_isend_irecv(comm_ops)
+
+    def switch_step():
+        for k in recv_tensors:
+            tmp_tensor = send_tensors[k]
+            send_tensors[k] = recv_tensors[k]
+            recv_tensors[k] = tmp_tensor
+
+    output_tensors = []
+
+    handles = communicate_step()
     # first round: special case, retrive from local tensor
-    output_tensors[0] = F.linear(input_, weight)
+    output_tensors.append(func(**input_to_gather, **input_local))
     for i in range(group_size - 2):
         for handle in handles:
             handle.wait()

-        tmp_tensor = send_tensor
-        send_tensor = recv_tensor
-        recv_tensor = tmp_tensor
+        switch_step()

-        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-        handles = dist.batch_isend_irecv([recv_op, send_op])
+        handles = communicate_step()

         # actual computation
-        output_tensors[i + 1] = F.linear(send_tensor, weight)
+        output_tensors.append(func(**send_tensors, **input_local))

     # final round: special case, no need to send/recv again
     for handle in handles:
         handle.wait()
-    output_tensors[group_size - 1] = F.linear(recv_tensor, weight)
-    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1)
+    output_tensors.append(func(**recv_tensors, **input_local))
+
+    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)


 class _GatherForwardReduceScatterBackward(torch.autograd.Function):
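Note on the hunk above: the refactor turns the F.linear-specific `_AllgatherLinear` into a generic `_ring_as_gather` that takes any compute function plus two keyword groups, the tensors that circulate around the ring (`input_to_gather`) and the tensors that stay on the local rank (`input_local`). A minimal usage sketch, not part of the commit: it assumes an already-initialized `torch.distributed` job (e.g. launched with torchrun), illustrative shapes, and that the private helper is importable from the module shown in this diff.

import torch
import torch.distributed as dist
import torch.nn.functional as F

from colossalai.shardformer.layer._operation import _ring_as_gather  # helper defined in this file

group = dist.group.WORLD                  # assumes dist.init_process_group(...) already ran
local_input = torch.randn(2, 128, 512)    # [batch, local sequence shard, hidden]
weight = torch.randn(1024, 512)           # replicated [out_features, hidden]

# Same calling convention as the ring=True branch added further down in this file.
output = _ring_as_gather(
    F.linear,
    input_to_gather={"input": local_input},   # circulated around the ring
    input_local={"weight": weight},           # stays on this rank
    process_group=group,
    gather_dim=1,
)
# Intended result: [2, 128 * world_size, 1024], equivalent to all-gathering the
# sequence dimension first and then applying F.linear once.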
@@ -247,6 +257,41 @@ def backward(ctx, grad_output):
         return output, None, None


+class _GatherForwardReduceScatterBackward(torch.autograd.Function):
+    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
+
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group, dim):
+        ctx.process_group = process_group
+        ctx.dim = dim
+
+        return _gather(input_, dim, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        dim = ctx.dim
+        process_group = ctx.process_group
+
+        # do reduce-scatter
+        new_shape = list(grad_output.shape)
+        assert (
+            new_shape[dim] % dist.get_world_size(process_group) == 0
+        ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+        new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
+        grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)]
+        output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
+        dist.reduce_scatter(output, grad_list, group=process_group)
+
+        return output, None, None
+
+
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """Gather input from sequence parallel in forward and reduce-scatter gradient in backward

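The new `_GatherForwardReduceScatterBackward` gathers along `dim` in forward and hand-rolls a reduce-scatter of the gradient in backward. The shape contract of that backward step, shown in isolation as a sketch: the world size, shapes, and the `process_group` handle below are illustrative assumptions, not taken from the commit.

import torch
import torch.distributed as dist

# Assumes torch.distributed is already initialized; say world_size == 4 and dim == 1.
process_group = dist.group.WORLD
world_size = dist.get_world_size(process_group)

grad_output = torch.randn(2, 512, 1024)  # gathered gradient: [batch, full sequence, hidden]
chunks = [c.contiguous() for c in torch.chunk(grad_output, world_size, dim=1)]

local_grad = torch.empty(2, 512 // world_size, 1024,
                         dtype=grad_output.dtype, device=grad_output.device)
dist.reduce_scatter(local_grad, chunks, group=process_group)
# Rank r now holds chunk r of the sequence dimension, summed over all ranks:
# [2, 128, 1024] when world_size == 4.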
@@ -258,19 +303,35 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
         ctx.overlap = overlap

-        if bias is not None:
-            input_parallel = _gather(input_, dim, process_group)
-            output = F.linear(input_parallel, weight, bias)
+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather['input'] = input_
+            input_local['weight'] = weight
+
+            output = _ring_as_gather(
+                F.linear,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+            )
+
+            if bias is not None:
+                output += bias
         else:
-            output = _AllgatherLinear(input_, weight, process_group)
+            input_parallel = _gather(input_, dim, process_group)
+            if bias is not None:
+                output = F.linear(input_parallel, weight, bias)
+            else:
+                output = F.linear(input_parallel, weight)

         return output

@@ -373,34 +434,43 @@ def backward(ctx, grad_output):
             # wait until reduce-scatter finished
             reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


-def _ReduceScatterLinear(input_, weight, process_group):
+def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1):
+    # currently only support one single tensor as output
     group_size = dist.get_world_size(process_group)
     cur_rank = dist.get_rank(process_group)

-    input_shape = input_.shape
-
     # initialization of ring communication
-    # communicate(e.g.): 0->1->2->3
-    # compute(e.g.): 3->2->1->0
-    input_tensors = list(torch.split(input_, int(input_shape[1] / group_size), dim=1))
-    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
-    input_tensors.reverse()
     recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
     send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+    input_tensors = []
+    for _ in range(group_size):
+        input_tensors.append({})
+    for k, v in input_to_reducescatter.items():
+        input_shape = v.shape
+        assert input_shape[reducescatter_dim] % group_size == 0
+        _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
+        for i in range(group_size):
+            input_tensors[i][k] = _input_tensors[i]
+    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
+    input_tensors.reverse()

-    # first round: special case, no reduce operation
-    output_tensor = F.linear(input_tensors[0], weight)
+    output_tensor = func(**input_tensors[0], **input_local)
     recv_tensor = output_tensor.clone()
     send_tensor = output_tensor.clone()
-    recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-    send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-    handles = dist.batch_isend_irecv([recv_op, send_op])
+
+    def communicate_step():
+        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
+        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
+        return dist.batch_isend_irecv([recv_op, send_op])
+
+    handles = communicate_step()
+    # first round: special case, retrive from local tensor
     for i in range(group_size - 2):
         # actual computation
-        output_tensor = F.linear(input_tensors[i + 1], weight)
+        output_tensor = func(**input_tensors[i + 1], **input_local)

         for handle in handles:
             handle.wait()
@@ -410,12 +480,10 @@ def _ReduceScatterLinear(input_, weight, process_group):
         send_tensor = output_tensor
         output_tensor = tmp_tensor

-        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-        handles = dist.batch_isend_irecv([recv_op, send_op])
+        handles = communicate_step()

     # final round: special case, no need to send/recv again
-    output_tensor = F.linear(input_tensors[group_size - 1], weight)
+    output_tensor = func(**input_tensors[-1], **input_local)
     for handle in handles:
         handle.wait()
     output_tensor += recv_tensor
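`_ring_as_reducescatter` is the mirror of `_ring_as_gather`: the full-sequence operands are split up front, each rank computes one partial output per step, and partial results are summed as they travel around the ring so every rank ends up with only its own shard. A usage sketch matching the `ring is True` branch in the next hunk; the shapes, the process group, and the import path are illustrative assumptions, and bias is left out of the circulated dict for simplicity.

import torch
import torch.distributed as dist
import torch.nn.functional as F

from colossalai.shardformer.layer._operation import _ring_as_reducescatter  # helper in this file

group = dist.group.WORLD                             # assumes an initialized job
world_size = dist.get_world_size(group)

full_input = torch.randn(2, 128 * world_size, 512)   # [batch, full sequence, local hidden shard]
weight = torch.randn(1024, 512)                      # row-parallel weight shard

output = _ring_as_reducescatter(
    F.linear,
    input_to_reducescatter={"input": full_input},    # split along dim 1 and consumed chunk by chunk
    input_local={"weight": weight},
    process_group=group,
)
# Intended result: [2, 128, 1024], this rank's sequence shard of the output,
# with partial products from all ranks already summed.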
@@ -433,27 +501,44 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, dim):
+    def forward(ctx, input_, weight, bias, process_group, dim, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.dim = dim
-        if bias is not None:
-            partial_output = F.linear(input_, weight, bias)
+
+        if ring is True:
+            input_to_reducescatter = {}
+            input_local = {}
+            input_to_reducescatter['input'] = input_
+            input_local['weight'] = weight
+
+            if bias is not None:
+                input_to_reducescatter['bias'] = bias
+
+            output = _ring_as_reducescatter(
+                F.linear,
+                input_to_reducescatter=input_to_reducescatter,
+                input_local=input_local,
+                process_group=process_group,
+            )
         else:
-            return _ReduceScatterLinear(input_, weight, process_group)
+            if bias is not None:
+                partial_output = F.linear(input_, weight, bias)
+            else:
+                partial_output = F.linear(input_, weight)

-        output_shape = list(partial_output.shape)
-        assert (
-            output_shape[dim] % dist.get_world_size(process_group) == 0
-        ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
-        output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
+            output_shape = list(partial_output.shape)
+            assert (
+                output_shape[dim] % dist.get_world_size(process_group) == 0
+            ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+            output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)

-        output_list = [
-            item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
-        ]
-        output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
-        dist.reduce_scatter(output, output_list, group=process_group)
+            output_list = [
+                item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
+            dist.reduce_scatter(output, output_list, group=process_group)

         return output

@@ -481,7 +566,7 @@ def backward(ctx, grad_output):
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        return grad_input, grad_weight, grad_bias, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None


 class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
@@ -530,17 +615,32 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
         ctx.overlap = overlap

-        input_parallel = _gather(input_, dim, process_group)
+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather['input'] = input_
+            input_local['other'] = weight
+
+            output = _ring_as_gather(
+                torch.matmul,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+                gather_dim=dim
+            )
+
+        else:
+            input_parallel = _gather(input_, dim, process_group)

-        output = torch.matmul(input_parallel, weight)
+            output = torch.matmul(input_parallel, weight)

         if bias is not None:
             output = output + bias
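The matmul variant reuses `_ring_as_gather`; the only differences from the F.linear path are the keyword for the local operand ('other', torch.matmul's second argument) and the explicit `gather_dim`. A short sketch under the same assumptions as the earlier example (initialized process group, illustrative shapes):

import torch
import torch.distributed as dist

from colossalai.shardformer.layer._operation import _ring_as_gather

group = dist.group.WORLD                    # assumes an initialized job
local_input = torch.randn(2, 128, 512)      # [batch, sequence shard, hidden]
weight_t = torch.randn(512, 1024)           # [hidden, out_features], the layout torch.matmul expects

output = _ring_as_gather(
    torch.matmul,
    input_to_gather={"input": local_input},
    input_local={"other": weight_t},        # 'other' is torch.matmul's second operand
    process_group=group,
    gather_dim=1,
)
# Intended result: [2, 128 * world_size, 1024]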
@@ -620,7 +720,7 @@ def backward(ctx, grad_output):
             # wait until reduce-scatter finished
             reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -873,10 +973,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre


 def linear_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _LinearWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )

@@ -888,15 +988,15 @@ def reducescatter_forward_gather_backward(input_, process_group, dim):
     return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)


-def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1):
-    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim)
+def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
+    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)


 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )

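At call sites (for GPT-2's column- and row-parallel linears in this commit), the ring path is selected through the new `ring` keyword on these wrappers rather than by calling the helpers directly. A hedged end-to-end sketch follows; the launch setup and shapes are assumptions, and the keyword values other than `ring` simply follow the signatures above.

import torch
import torch.distributed as dist

from colossalai.shardformer.layer._operation import (
    linear_gather_forward_reducescatter_backward,
)

# Assumes a torchrun launch with an initialized default process group.
group = dist.group.WORLD

hidden_states = torch.randn(2, 128, 768, requires_grad=True)   # [batch, sequence shard, hidden]
weight = torch.randn(3072, 768, requires_grad=True)
bias = torch.randn(3072, requires_grad=True)

out = linear_gather_forward_reducescatter_backward(
    hidden_states, weight, bias, group,
    async_grad_reduce_scatter=True, dim=1, overlap=False, ring=True,
)
# Forward: ring-style gather of the sequence dimension fused with F.linear.
# Backward: the input gradient is reduce-scattered back to per-rank shards.
out.sum().backward()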
