7 changes: 5 additions & 2 deletions colossalai/inference/engine/engine.py
@@ -33,13 +33,16 @@ class InferenceEngine:
     Args:
         tp_size (int): the size of tensor parallelism.
         pp_size (int): the size of pipeline parallelism.
         dtype (str): the data type of the model, should be one of 'fp16', 'fp32', 'bf16'.
         model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
-        model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model.
-        micro_batch_size (int): the micro batch size.
+        model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used to shard the model with `ShardFormer`. It will be determined by the model type if not provided.
+        micro_batch_size (int): the micro batch size. Only used when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
         max_batch_size (int): the maximum batch size.
         max_input_len (int): the maximum input length.
         max_output_len (int): the maximum output length.
+        quant (str): the quantization method, should be one of 'smoothquant', 'gptq', None.
+        verbose (bool): whether to return the time cost of each step.
+
     """

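For orientation, here is a minimal sketch of driving the constructor documented above. It mirrors the benchmark script added later in this PR; the toy config, sizes, and launch context are illustrative assumptions, not part of the diff:

```python
import transformers

from colossalai.inference import InferenceEngine

# Toy Llama model; any non-pipeline nn.Module that ShardFormer supports works here.
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))

# Must run inside a distributed context (e.g. colossalai.launch), as in the benchmark below.
engine = InferenceEngine(
    tp_size=1,
    pp_size=2,
    dtype="fp16",
    model=model,          # model_policy omitted: determined from the model type
    micro_batch_size=1,   # only used because pp_size > 1
    max_batch_size=4,
    max_input_len=32,
    max_output_len=32,
    verbose=True,         # also return the time cost of each step
)
```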
2 changes: 2 additions & 0 deletions colossalai/pipeline/schedule/generate.py
@@ -69,6 +69,8 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
         batch = tree_map(partial(to_device, device=device), batch)
         self.batch = batch
         self.batch_size = get_batch_size(batch)
+        if self.stage_manager.num_stages == 1:
+            self.microbatch_size = self.batch_size
         self.microbatch_offset = 0
         assert (
             self.batch_size % self.microbatch_size == 0
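The new branch exists to satisfy the divisibility assertion that follows it: with a single pipeline stage there is no real micro-batching, so the micro batch size is pinned to the full batch size. A toy illustration of the invariant (the function name is hypothetical):

```python
def resolve_microbatch_size(batch_size: int, microbatch_size: int, num_stages: int) -> int:
    # With one stage there is no pipelining; the whole batch is a single micro batch.
    if num_stages == 1:
        microbatch_size = batch_size
    assert batch_size % microbatch_size == 0, "batch size must be divisible by micro batch size"
    return microbatch_size

assert resolve_microbatch_size(8, 3, num_stages=1) == 8  # pinned to the full batch
assert resolve_microbatch_size(8, 2, num_stages=4) == 2  # 4 micro batches of 2
```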
19 changes: 0 additions & 19 deletions examples/inference/_utils.py

This file was deleted.

167 changes: 0 additions & 167 deletions examples/inference/benchmark.py

This file was deleted.

168 changes: 168 additions & 0 deletions examples/inference/benchmark_llama.py
@@ -0,0 +1,168 @@
import argparse
import time

import torch
import torch.distributed as dist
import transformers

import colossalai
import colossalai.utils.device as device_utils
from colossalai.inference import InferenceEngine
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils.device import get_current_device

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

CONFIG_MAP = {
    "toy": transformers.LlamaConfig(num_hidden_layers=4),
    "llama-7b": transformers.LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=32,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    ),
    "llama-13b": transformers.LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_attention_heads=40,
        num_hidden_layers=40,
        num_key_value_heads=40,
        max_position_embeddings=2048,
    ),
    "llama2-7b": transformers.LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=32,
        num_key_value_heads=32,
        max_position_embeddings=4096,
    ),
    "llama2-13b": transformers.LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_attention_heads=40,
        num_hidden_layers=40,
        num_key_value_heads=40,
        max_position_embeddings=4096,
    ),
}


def data_gen(batch_size: int = 4, seq_len: int = 512):
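    # Random token ids within the Llama vocab range, with a full (no-padding) attention mask.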
    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
    attention_mask = torch.ones_like(input_ids)
    data = dict(input_ids=input_ids, attention_mask=attention_mask)
    return data


def print_details_info(outputs, model_config, args, whole_end2end):
    msg: str = ""

    if dist.get_rank() == 0:
        msg += "-------Perf Summary-------\n"
        if args.verbose:
            timestamps = outputs[1]
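            # Each micro batch's timestamp list records: start, end of prefill,
            # one entry per decode step, and the final completion time.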
            prefill = []
            encoder = []
            end2end = []
            for timestamp in timestamps:
                prefill.append(timestamp[1] - timestamp[0])
                encoder.append(
                    sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
                )
                end2end.append(timestamp[-1] - timestamp[0])

            mb_avg_end2end = sum(end2end) / len(end2end)
            mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size)

            msg += f"Average prefill time: {sum(prefill) / len(prefill) * 1000:.2f} ms\n"
            msg += f"Average encode time: {sum(encoder) / len(encoder) * 1000:.2f} ms\n"
            msg += f"Average micro batch end2end time: {mb_avg_end2end * 1000:.2f} ms\n"
            msg += f"Average micro batch per token latency: {mb_avg_latency * 1000:.2f} ms\n"

        whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
        num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
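        # Rough per-stage weight count: ~12 * hidden_size^2 parameters per transformer
        # layer (attention + MLP), divided across the pipeline stages.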
        num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
        if args.dtype in ["fp16", "bf16"]:
            num_bytes = 2
        else:
            num_bytes = 4

        msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n"
        msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n"
        msg += f"Throughput: {args.output_len * args.batch_size / whole_end2end:.2f} tokens/s\n"
        msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"

    if torch.cuda.is_available():
        msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n"
        msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n"
        msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n"

    print(msg)


def benchmark_inference(args):
    config = CONFIG_MAP[args.model]
    model = transformers.LlamaForCausalLM(config)
    if dist.get_rank() == 0:
        print("Model loaded")
    engine = InferenceEngine(
        pp_size=args.pp_size,
        tp_size=args.tp_size,
        dtype=args.dtype,
        micro_batch_size=args.mb_size,
        model=model,
        verbose=args.verbose,
        max_batch_size=args.batch_size,
        max_input_len=args.seq_len,
        max_output_len=args.output_len,
    )
    data = data_gen(args.batch_size, args.seq_len)

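    # Warm-up generations let CUDA kernels compile and allocator caches fill before timing.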
    N_WARMUP_STEPS = 2

    for _ in range(N_WARMUP_STEPS):
        engine.generate(data)

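    # Bracket the timed call with synchronize() so the wall-clock window covers only generation.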
    torch.cuda.synchronize()
    whole_end2end = time.time()
    outputs = engine.generate(data)
    torch.cuda.synchronize()
    whole_end2end = time.time() - whole_end2end

    print_details_info(outputs, model.config, args, whole_end2end)


def hybrid_inference(rank, world_size, port, args):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    benchmark_inference(args)


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def benchmark(args):
    spawn(hybrid_inference, nprocs=args.tp_size * args.pp_size, args=args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        default="toy",
        help="the size of model",
        choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
    )
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
    parser.add_argument("--mb_size", type=int, default=1, help="micro batch size")
    parser.add_argument("--pp_size", type=int, default=1, help="pipeline parallel size")
    parser.add_argument("--tp_size", type=int, default=1, help="tensor parallel size")
    parser.add_argument("--output_len", type=int, default=128, help="output length")
    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
    parser.add_argument("-v", "--verbose", default=False, action="store_true")
    args = parser.parse_args()
    benchmark(args)
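Note: since `spawn` starts the `tp_size * pp_size` worker processes itself, the script runs directly, e.g. `python benchmark_llama.py -m llama2-7b -b 16 -s 256 --pp_size 2 --tp_size 2 -v` (an illustrative invocation, not taken from this diff).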