31 changes: 24 additions & 7 deletions colossalai/inference/hybridengine/engine.py
@@ -14,10 +14,7 @@

PP_AXIS, TP_AXIS = 0, 1

_supported_models = [
"LlamaForCausalLM",
"BloomForCausalLM",
]
_supported_models = ["LlamaForCausalLM", "BloomForCausalLM", "LlamaGPTQForCausalLM", "SmoothLlamaForCausalLM"]


class CaiInferEngine:
@@ -70,12 +67,21 @@ def __init__(
max_batch_size: int = 4,
max_input_len: int = 32,
max_output_len: int = 32,
quant: str = None,
verbose: bool = False,
# TODO: implement early_stopping and various generation options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
) -> None:
if quant == "gptq":
from ..quant.gptq import GPTQManager

self.gptq_manager = GPTQManager(model.quantize_config, max_input_len=max_input_len)
model = model.model
elif quant == "smoothquant":
model = model.model

assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
assert (
tp_size * pp_size == dist.get_world_size()
@@ -85,9 +91,14 @@ def __init__(

assert max_batch_size <= 64, "Max batch size exceeds the constraint"
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"

assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq', or None"
self.pp_size = pp_size
self.tp_size = tp_size
self.quant = quant

if quant == "smoothquant" and dtype != "fp32":
dtype = "fp32"
print("Warning: smoothquant only support fp32 and int8 mix precision. set dtype to fp32")

if dtype == "fp16":
self.dtype = torch.float16
@@ -118,6 +129,8 @@ def __init__(
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)

self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
if quant == "gptq":
self.gptq_manager.post_init_gptq_buffer(self.model)

def inference(self, input_list):
"""
@@ -149,6 +162,7 @@ def _shardformer(self, model, model_policy, stage_manager, tp_group):
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
quant=self.quant,
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
@@ -158,7 +172,7 @@ def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_outp
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
if model.config.model_type == "llama":
head_dim = model.config.hidden_size // model.config.num_attention_heads
head_num = model.config.num_attention_heads // self.tp_size
head_num = model.config.num_key_value_heads // self.tp_size
num_hidden_layers = (
model.config.num_hidden_layers
if hasattr(model.config, "num_hidden_layers")
@@ -171,5 +185,8 @@ def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_outp
num_hidden_layers = model.config.n_layer
layer_num = num_hidden_layers // self.pp_size

cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
if self.quant == "smoothquant":
cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
else:
cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
return cache_manager
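
For context, a minimal usage sketch of the engine with the new quant argument, based only on the constructor and inference() signatures visible in this diff. The launcher call, load_quantized_llama(), LlamaModelInferPolicy, and input_list are assumptions/placeholders, not names confirmed by this PR.

import colossalai
from colossalai.inference.hybridengine.engine import CaiInferEngine

colossalai.launch_from_torch(config={})  # assumed launcher; tp_size * pp_size must equal the world size

model = load_quantized_llama()  # hypothetical loader returning a GPTQ-wrapped Llama exposing .model and .quantize_config
engine = CaiInferEngine(
    tp_size=2,
    pp_size=2,
    model=model,
    model_policy=LlamaModelInferPolicy(),  # assumed inference policy for Llama
    quant="gptq",  # or "smoothquant", or None for an unquantized model
    max_batch_size=4,
    max_input_len=32,
    max_output_len=32,
)
outputs = engine.inference(input_list)  # input_list: tokenized batch prepared by the caller
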
3 changes: 2 additions & 1 deletion colossalai/inference/hybridengine/modeling/__init__.py
@@ -1,3 +1,4 @@
from .bloom import BloomInferenceForwards
from .llama import LlamaInferenceForwards

__all__ = ["LlamaInferenceForwards"]
__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards"]
62 changes: 56 additions & 6 deletions colossalai/inference/hybridengine/polices/llama.py
@@ -45,14 +45,15 @@ def __init__(self) -> None:

def module_policy(self):
policy = super().module_policy()

if self.shard_config.inference_gptq:
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
if self.shard_config.quant == "gptq":
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
@@ -94,6 +95,55 @@ def module_policy(self):
],
)

elif self.shard_config.quant == "smoothquant":
from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
from colossalai.inference.quant.smoothquant.models.parallel_linear import (
ColW8A8BFP32OFP32Linear,
RowW8A8B8O8Linear,
RowW8A8BFP32O32LinearSiLU,
RowW8A8BFP32OFP32Linear,
)

policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=RowW8A8BFP32O32LinearSiLU,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=RowW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
],
)
self.shard_config._infer()

infer_forward = LlamaInferenceForwards.llama_model_forward
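
For readability, the smoothquant branch above boils down to the following suffix-to-module mapping; the names are copied verbatim from the policy, so this is only a condensed view of the replacements, not a new API.

# Llama decoder sub-module -> smoothquant parallel linear, as declared in the policy above
SMOOTHQUANT_LLAMA_REPLACEMENTS = {
    "self_attn.q_proj": "RowW8A8B8O8Linear",
    "self_attn.k_proj": "RowW8A8B8O8Linear",
    "self_attn.v_proj": "RowW8A8B8O8Linear",
    "self_attn.o_proj": "ColW8A8BFP32OFP32Linear",
    "mlp.gate_proj": "RowW8A8BFP32O32LinearSiLU",
    "mlp.up_proj": "RowW8A8BFP32OFP32Linear",
    "mlp.down_proj": "ColW8A8BFP32OFP32Linear",
}

The fp32-output variants on o_proj and down_proj line up with keeping the residual additions in floating point.
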
1 change: 1 addition & 0 deletions colossalai/inference/quant/__init__.py
@@ -0,0 +1 @@
from .smoothquant.models.llama import SmoothLlamaForCausalLM
1 change: 1 addition & 0 deletions colossalai/inference/quant/gptq/__init__.py
@@ -2,3 +2,4 @@

if HAS_AUTO_GPTQ:
from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
from .gptq_manager import GPTQManager
61 changes: 61 additions & 0 deletions colossalai/inference/quant/gptq/gptq_manager.py
@@ -0,0 +1,61 @@
import warnings

import torch


class GPTQManager:
def __init__(self, quant_config, max_input_len: int = 1):
self.max_dq_buffer_size = 1
self.max_inner_outer_dim = 1
self.bits = quant_config.bits
self.use_act_order = quant_config.desc_act
self.max_input_len = max_input_len
self.gptq_temp_state_buffer = None
self.gptq_temp_dq_buffer = None
self.quant_config = quant_config

def post_init_gptq_buffer(self, model: torch.nn.Module) -> None:
from .cai_gptq import CaiQuantLinear

HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder

gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False

for name, submodule in model.named_modules():
if isinstance(submodule, CaiQuantLinear):
self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)

if self.use_act_order:
self.max_inner_outer_dim = max(
self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
)
self.bits = submodule.bits
if not (HAS_GPTQ_CUDA and self.bits == 4):
return

max_input_len = 1
if self.use_act_order:
max_input_len = self.max_input_len
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.gptq_temp_state_buffer = torch.zeros(
(max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
)
self.gptq_temp_dq_buffer = torch.zeros(
(1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
)

gptq_cuda.prepare_buffers(
torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

torch.cuda.empty_cache()
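
A short sketch of how GPTQManager is driven by the engine in this PR: it is built from the wrapper's quantize_config before sharding, and post_init_gptq_buffer is called on the sharded model to size and register the exllama-style temp buffers. gptq_model and shard_with_shardformer below are hypothetical placeholders.

from colossalai.inference.quant.gptq import GPTQManager

manager = GPTQManager(gptq_model.quantize_config, max_input_len=32)  # gptq_model: a GPTQ-wrapped Llama (assumption)
raw_model = gptq_model.model                       # unwrap, mirroring the quant == "gptq" branch in CaiInferEngine
sharded_model = shard_with_shardformer(raw_model)  # placeholder for the ShardFormer step in _shardformer()
manager.post_init_gptq_buffer(sharded_model)       # allocates fp16 temp_state/temp_dq buffers on the current CUDA device
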
4 changes: 1 addition & 3 deletions colossalai/inference/quant/smoothquant/models/__init__.py
@@ -4,9 +4,7 @@
HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
raise ImportError(
"Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
)
print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")

if HAS_TORCH_INT:
from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
10 changes: 9 additions & 1 deletion colossalai/inference/quant/smoothquant/models/base_model.py
@@ -9,7 +9,6 @@
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union

import accelerate
import numpy as np
import torch
import torch.nn as nn
@@ -24,6 +23,15 @@
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager

try:
import accelerate

HAS_ACCELERATE = True
except ImportError:
HAS_ACCELERATE = False
print("accelerate is not installed.")


SUPPORTED_MODELS = ["llama"]


26 changes: 18 additions & 8 deletions colossalai/inference/quant/smoothquant/models/linear.py
@@ -1,17 +1,25 @@
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py

import torch
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax

try:
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax

HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")


try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except ImportError:
except Exception:
HAS_SMOOTHQUANT_CUDA = False
raise ImportError("CUDA smoothquant linear is not installed")
print("CUDA smoothquant linear is not installed")


class W8A8BFP32O32LinearSiLU(torch.nn.Module):
@@ -138,21 +146,23 @@ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
)
self.register_buffer(
"bias",
torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False),
)
self.register_buffer("a", torch.tensor(alpha))

def _apply(self, fn):
# prevent the bias from being converted to half
super()._apply(fn)
self.bias = self.bias.to(torch.float32)
if self.bias is not None:
self.bias = self.bias.to(torch.float32)
return self

def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
self.bias = self.bias.to(torch.float32)
if self.bias is not None:
self.bias = self.bias.to(*args, **kwargs)
self.bias = self.bias.to(torch.float32)
return self

@torch.no_grad()
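
The _apply override above is what keeps the bias buffer in fp32 even when the whole model is cast with .half() or .to(torch.float16); a self-contained sketch of that pattern, independent of torch-int:

import torch

class KeepsFP32Bias(torch.nn.Module):
    # Minimal illustration: an int8 weight buffer plus a bias buffer pinned to fp32.
    def __init__(self, out_features: int):
        super().__init__()
        self.register_buffer("weight", torch.zeros(out_features, out_features, dtype=torch.int8))
        self.register_buffer("bias", torch.zeros((1, out_features), dtype=torch.float32))

    def _apply(self, fn):
        # Let PyTorch move/cast every buffer, then force the bias back to fp32.
        super()._apply(fn)
        if self.bias is not None:
            self.bias = self.bias.to(torch.float32)
        return self

m = KeepsFP32Bias(8).half()
assert m.weight.dtype == torch.int8 and m.bias.dtype == torch.float32  # bias survives the half() cast
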