2 changes: 1 addition & 1 deletion colossalai/inference/__init__.py
@@ -1,3 +1,3 @@
from .pipeline import PPInferEngine

__all__ = ['PPInferEngine']
__all__ = ["PPInferEngine"]
2 changes: 1 addition & 1 deletion colossalai/inference/pipeline/__init__.py
@@ -1,3 +1,3 @@
from .engine import PPInferEngine

__all__ = ['PPInferEngine']
__all__ = ["PPInferEngine"]
97 changes: 58 additions & 39 deletions colossalai/inference/pipeline/benchmark/benchmark.py
@@ -1,28 +1,32 @@
import argparse
import time

import torch
import torch.distributed as dist
import transformers

import colossalai
import time
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
import argparse
GIGABYTE = 1024 ** 3

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

colossalai.launch_from_torch(config={})

def data_gen(batch_size: int=4, seq_len: int=512):

def data_gen(batch_size: int = 4, seq_len: int = 512):
input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
data = dict(input_ids=input_ids, attention_mask=attention_mask)
for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = batch_size
data[k] = v.to('cuda').repeat(*new_shape)
data[k] = v.to("cuda").repeat(*new_shape)
return data


def print_details_info(timestamps, model_config, args, whole_end2end):
if dist.get_rank() == 0:
prefill = []
@@ -31,32 +35,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
for timestamp in timestamps:
prefill.append(timestamp[1] - timestamp[0])
encoder.append(
sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2))
sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
)
end2end.append(timestamp[-1] - timestamp[0])
print(whole_end2end)
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
mb_avg_end2end = sum(end2end)/len(end2end)
mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size)
whole_avg_latency = whole_end2end/(args.new_length * args.batch_size)
with open(
f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
"w+",
) as f:
mb_avg_end2end = sum(end2end) / len(end2end)
mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size)
whole_avg_latency = whole_end2end / (args.new_length * args.batch_size)
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
if args.dtype in ['fp16','bf16']:
if args.dtype in ["fp16", "bf16"]:
num_bytes = 2
else:
num_bytes = 4

f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000))
f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000))
f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000))
f.write(
f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n"
)
f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000))
f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000))))
f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12))
f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000))))
f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
f.write("----------------------------------------------------------\n")


if torch.cuda.is_available():
current_device = torch.cuda.current_device()

@@ -66,7 +75,10 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
max_memory_allocated = torch.cuda.max_memory_allocated()
memory_reserved = torch.cuda.memory_reserved()
max_memory_reserved = torch.cuda.max_memory_reserved()
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
with open(
f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
"a",
) as f:
f.write(
f"\nCurrently using GPU: {current_device}\n"
f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
@@ -77,29 +89,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
)

if __name__ == '__main__':

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='toy', help='the size of model')
parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size')
parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length')
parser.add_argument('--new_length', type=int, default=4, help='new tokens length')
parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size')
parser.add_argument('--pp_size', type=int, default=2, help='pipeline size')
parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log')
parser.add_argument('--dtype', type=str, default='fp16', help='data type')
parser.add_argument("--model", default="toy", help="the size of model")
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
parser.add_argument("--new_length", type=int, default=4, help="new tokens length")
parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
parser.add_argument("--pp_size", type=int, default=2, help="pipeline size")
parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log")
parser.add_argument("--dtype", type=str, default="fp16", help="data type")
args = parser.parse_args()

if args.model == 'toy':
if args.model == "toy":
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
elif args.model == '7b':
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf'))
elif args.model == '13b':
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf'))
elif args.model == "7b":
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf"))
elif args.model == "13b":
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf"))
else:
raise NotImplementedError


engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)

engine = PPInferEngine(
pp_size=args.pp_size,
dtype=args.dtype,
micro_batch_size=args.mb_size,
new_length=args.new_length,
model=model,
model_policy=LlamaForCausalLMPipelinePolicy(),
verbose=True,
)
data = data_gen(args.batch_size, args.seq_len)

torch.cuda.synchronize()
@@ -109,4 +129,3 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
whole_end2end = time.time() - whole_end2end

print_details_info(timestamps, model.config, args, whole_end2end)
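Note: for context, below is a minimal standalone driver mirroring the reformatted benchmark. It is a sketch only: it assumes a multi-GPU launch via `torchrun` (e.g. `--nproc_per_node 2` for `pp_size=2`) and that `engine.inference` accepts the same `input_ids`/`attention_mask` dict that `data_gen` builds; the exact return value is not pinned down by this diff, so it is treated as opaque here.

```python
import time

import torch
import transformers

import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy

# Launch with: torchrun --nproc_per_node 2 this_script.py  (assumption: 2 pipeline stages)
colossalai.launch_from_torch(config={})

# Tiny randomly initialized Llama so the sketch runs without downloading weights.
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))

engine = PPInferEngine(
    pp_size=2,
    dtype="fp16",
    micro_batch_size=1,
    new_length=4,
    model=model,
    model_policy=LlamaForCausalLMPipelinePolicy(),
    verbose=True,
)

# One batch of random token ids, matching what data_gen() produces above.
batch_size, seq_len = 8, 8
data = dict(
    input_ids=torch.randint(10, 30000, (batch_size, seq_len), dtype=torch.int32, device="cuda"),
    attention_mask=torch.ones((batch_size, seq_len), dtype=torch.int32, device="cuda"),
)

torch.cuda.synchronize()
start = time.time()
output = engine.inference(data)  # per-micro-batch timing info when verbose=True (assumption)
torch.cuda.synchronize()
print(f"whole batch end-to-end: {time.time() - start:.3f} s")
```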

21 changes: 10 additions & 11 deletions colossalai/inference/pipeline/engine.py
@@ -1,5 +1,3 @@
from typing import Callable, List, Optional, Set, Union

import torch
import torch.nn as nn

@@ -13,7 +11,7 @@


class PPInferEngine:
'''
"""
PPInferEngine is a class that handles the pipeline parallel inference.

Args:
@@ -41,20 +39,20 @@ class PPInferEngine:
output = engine.inference([tokenized_input])
```

'''
"""

def __init__(
self,
pp_size: int,
dtype: str = 'fp16',
dtype: str = "fp16",
pp_model: nn.Module = None,
model: nn.Module = None,
model_policy: Policy = None,
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
verbose: bool = False,
# TODO: implement early_stopping, and various gerneration options
# TODO: implement early_stopping, and various gerneration options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
@@ -63,15 +61,16 @@ def __init__(
self.pp_size = pp_size
self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
micro_batch_buffer_size or pp_size)
self.mb_manager = MicroBatchManager(
self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
)
self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)

assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'"
if dtype == 'fp16':
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
if dtype == "fp16":
model.half()
elif dtype == 'bf16':
elif dtype == "bf16":
model.to(torch.bfloat16)
self.model = pp_model or self._shardformer(model, model_policy)

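Note: the dtype handling in this hunk is unchanged in behavior; only the string quoting and line wrapping differ. As a small illustration of the pattern (a sketch, not engine code; the helper name is made up):

```python
import torch
import torch.nn as nn


def cast_for_inference(model: nn.Module, dtype: str = "fp16") -> nn.Module:
    """Cast an eager model before sharding, mirroring the assertion and branches above."""
    assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
    if dtype == "fp16":
        model.half()              # in-place cast to float16
    elif dtype == "bf16":
        model.to(torch.bfloat16)  # in-place cast to bfloat16
    # "fp32" leaves the model in its default precision
    return model
```

Both `half()` and `to()` modify the module in place, so the caller can keep using the same model reference it passed in.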
36 changes: 19 additions & 17 deletions colossalai/inference/pipeline/microbatch_manager.py
@@ -3,7 +3,7 @@

import torch

__all__ = 'MicroBatchManager'
__all__ = "MicroBatchManager"


class Status(Enum):
@@ -13,7 +13,7 @@ class Status(Enum):
COOLDOWN = 4


class MicroBatchDescription():
class MicroBatchDescription:
"""
This is the class to record the infomation of each microbatch, and also do some update operation.
This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
@@ -30,14 +30,14 @@ def __init__(
output_dict: Dict[str, torch.Tensor],
new_length: int,
) -> None:
assert output_dict.get('hidden_states') is not None
self.mb_length = output_dict['hidden_states'].shape[-2]
assert output_dict.get("hidden_states") is not None
self.mb_length = output_dict["hidden_states"].shape[-2]
self.target_length = self.mb_length + new_length
self.kv_cache = ()

def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
if output_dict is not None:
self._update_kvcache(output_dict['past_key_values'])
self._update_kvcache(output_dict["past_key_values"])

def _update_kvcache(self, kv_cache: Tuple):
assert type(kv_cache) == tuple
@@ -64,7 +64,6 @@ def cur_length(self):
Return the current sequnence length of micro batch

"""
pass


class HeadMicroBatchDescription(MicroBatchDescription):
@@ -80,13 +79,14 @@ class HeadMicroBatchDescription(MicroBatchDescription):

"""

def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
new_length: int) -> None:
def __init__(
self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
) -> None:
super().__init__(inputs_dict, output_dict, new_length)
assert inputs_dict is not None
assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None
self.input_ids = inputs_dict['input_ids']
self.attn_mask = inputs_dict['attention_mask']
assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
self.input_ids = inputs_dict["input_ids"]
self.attn_mask = inputs_dict["attention_mask"]
self.new_tokens = None

def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@@ -104,7 +104,8 @@ def _update_newtokens(self, new_token: torch.Tensor):

def _update_attnmask(self):
self.attn_mask = torch.cat(
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
)

@property
def cur_length(self):
@@ -127,8 +128,9 @@ class BodyMicroBatchDescription(MicroBatchDescription):
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
"""

def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
new_length: int) -> None:
def __init__(
self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
) -> None:
super().__init__(inputs_dict, output_dict, new_length)

def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@@ -146,8 +148,8 @@ def cur_length(self):
return self.kv_cache[0][0].shape[-2] + 1


class MicroBatchManager():
'''
class MicroBatchManager:
"""
MicroBatchManager is a class that manages the micro batch.

Args:
@@ -156,7 +158,7 @@ class MicroBatchManager():
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.

'''
"""

def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
self.stage = stage
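Note: the `_update_attnmask` hunk above is a pure reflow; the logic is unchanged and appends one column of ones to the attention mask per generated token. A standalone illustration of that growth, with made-up shapes and on CPU for simplicity (the real code uses `device="cuda"`):

```python
import torch

attn_mask = torch.ones((4, 16), dtype=torch.int64)  # batch of 4, prompt length 16
for _ in range(3):  # three decoding steps
    new_col = torch.ones((attn_mask.shape[0], 1), dtype=torch.int64)
    attn_mask = torch.cat((attn_mask, new_col), dim=-1)

print(attn_mask.shape)  # torch.Size([4, 19]) -- one extra column per step
```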