6 changes: 6 additions & 0 deletions colossalai/inference/quant/smoothquant/models/base_model.py
@@ -132,6 +132,7 @@ def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samp
mean_scale = np.mean([v["input"] for v in act_dict.values()])
pbar.set_description(f"Mean input scale: {mean_scale:.2f}")

+# Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
model.eval()
device = next(model.parameters()).device
@@ -163,6 +164,7 @@ def stat_input_hook(m, x, y, name):

return act_scales

+# Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
@torch.no_grad()
def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
if not isinstance(fcs, list):
@@ -189,6 +191,7 @@ def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
def create_quantized_model(model):
raise NotImplementedError("create_quantized_model method is not implemented")

+# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
def save_quantized(
self,
save_dir: str,
@@ -249,6 +252,7 @@ def save_quantized(

self.model.config.save_pretrained(save_dir)

+# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
def save_pretrained(
self,
save_dir: str,
@@ -260,6 +264,7 @@ def save_pretrained(
warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)

+# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
@classmethod
def from_pretrained(
cls,
@@ -354,6 +359,7 @@ def skip(*args, **kwargs):

return cls(model, False)

+# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
@classmethod
def from_quantized(
cls,
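For reviewers unfamiliar with the newly attributed helpers: get_act_scales records the per-channel maximum absolute input of each linear layer, and smooth_ln_fcs uses those statistics to migrate quantization difficulty from activations into weights. A minimal standalone sketch of that smoothing step, assuming a single affine LayerNorm feeding one linear layer (the helper name is illustrative; the alpha default follows the upstream SmoothQuant repo):

import torch

@torch.no_grad()
def smooth_ln_fc_sketch(ln: torch.nn.LayerNorm, fc: torch.nn.Linear,
                        act_scales: torch.Tensor, alpha: float = 0.5):
    # Per input channel j: s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)
    weight_scales = fc.weight.abs().max(dim=0).values.clamp(min=1e-5)
    scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)
    # Folding 1/s into the LayerNorm and s into the linear leaves fc(ln(x))
    # numerically unchanged while shrinking activation outliers for int8.
    ln.weight.div_(scales)
    ln.bias.div_(scales)
    fc.weight.mul_(scales.unsqueeze(0))
    return scales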
2 changes: 2 additions & 0 deletions colossalai/inference/quant/smoothquant/models/linear.py
@@ -62,6 +62,7 @@ def from_float(module: torch.nn.Linear, input_scale)
return int8_module


+# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
class W8A8B8O8Linear(torch.nn.Module):
# For qkv_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
@@ -117,6 +118,7 @@ def from_float(module: torch.nn.Linear, input_scale, output_scale):
return int8_module


+# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
class W8A8BFP32OFP32Linear(torch.nn.Module):
# For fc2 and out_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
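The two torch-int derived classes above wrap fused CUDA int8 GEMMs (int8 activations and weights, with int8 or fp32 outputs). As a correctness-level sketch of the arithmetic only, here is a float-emulated equivalent; the class name and buffer layout are illustrative, and integer matmul like this runs only on CPU in stock PyTorch:

import torch

class W8A8LinearSketch(torch.nn.Module):
    """Float emulation of an int8-weight / int8-activation linear."""

    def __init__(self, weight_int8: torch.Tensor, weight_scale: float,
                 input_scale: float, output_scale: float):
        super().__init__()
        self.register_buffer("weight_int8", weight_int8)  # (out, in), torch.int8
        # A single fused multiplier maps int32 accumulators to the output grid.
        self.alpha = input_scale * weight_scale / output_scale

    def forward(self, x_int8: torch.Tensor) -> torch.Tensor:
        acc = x_int8.to(torch.int32) @ self.weight_int8.t().to(torch.int32)
        y = torch.round(acc.to(torch.float32) * self.alpha)
        return y.clamp(-128, 127).to(torch.int8)  # requantized int8 output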
3 changes: 3 additions & 0 deletions colossalai/inference/quant/smoothquant/models/llama.py
@@ -419,6 +419,7 @@ def forward(self, x, cos, sin, position_ids):
return x_embed


+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
@@ -559,6 +560,7 @@ def init_to_get_rotary(config, base=10000, use_elem=False):
return _cos_cached, _sin_cached


+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def llama_model_forward(
self,
@@ -729,6 +731,7 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
def __init__(self, model: PreTrainedModel, quantized: bool = False):
super().__init__(model, quantized)

+# Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
def get_act_dict(
self,
tokenizer,
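The last llama.py hunk ends inside init_to_get_rotary, which precomputes the cos/sin caches (_cos_cached, _sin_cached) that the quantized decoder reuses at every step. For reference, a minimal sketch of that standard RoPE precomputation, with the use_elem branch and dtype handling of the real function omitted:

import torch

def rotary_cache_sketch(head_dim: int, max_positions: int, base: float = 10000.0):
    # Standard RoPE frequencies: inv_freq[i] = 1 / base^(2i / head_dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_positions).float()
    freqs = torch.outer(t, inv_freq)  # (max_positions, head_dim // 2)
    return freqs.cos(), freqs.sin()   # plays the role of _cos_cached / _sin_cached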
7 changes: 6 additions & 1 deletion colossalai/inference/tensor_parallel/engine.py
@@ -21,6 +21,8 @@
"BloomForCausalLM",
"ChatGLMModel",
"ChatGLMForConditionalGeneration",
+"LlamaGPTQForCausalLM",
+"BloomGPTQForCausalLM",
]


@@ -213,11 +215,14 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
+
+model = model.model if self.shard_config.inference_gptq else model
+
policy = get_autopolicy(model, inference_only=True)
self.model, _ = shardformer.optimize(model, policy)

if self.shard_config.inference_gptq:
-self._post_init_gptq_buffer(model)
+self._post_init_gptq_buffer(self.model)

self.model = self.model.cuda()

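The engine change above unwraps the GPTQ wrapper (model.model) before policy lookup and points _post_init_gptq_buffer at the sharded self.model rather than the pre-shard module. A sketch of the calling pattern that exercises this path; the batch and length values are purely illustrative, and the generate() kwargs are assumed to follow the HF convention:

import torch
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig

# `model` is a GPTQ-wrapped causal LM (e.g. loaded via AutoGPTQ, as in the
# example script later in this diff); `input_ids` is a tokenized prompt batch.
shard_config = ShardConfig(
    enable_tensor_parallelism=True, inference_only=True, inference_gptq=True
)
engine = TPInferEngine(
    model, shard_config, max_batch_size=4, max_input_len=128, max_output_len=128
)
outputs = engine.generate(input_ids, do_sample=False, max_new_tokens=64)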
1 change: 1 addition & 0 deletions colossalai/kernel/triton/gptq_triton.py
@@ -267,6 +267,7 @@ def cai_gptq_matmul_248_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)


+# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
@autotune(
configs=[
triton.Config(
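cai_gptq_matmul_248_kernel multiplies fp16 activations against GPTQ-packed weights in which each int32 of qweight holds eight 4-bit values (or sixteen 2-bit / four 8-bit ones, hence "248"). A host-side sketch of the 4-bit unpack-and-dequantize the kernel performs inline; group-wise scales and zero points are collapsed to per-column tensors for brevity, and all names are illustrative:

import torch

def unpack_gptq_int4_sketch(qweight: torch.Tensor,  # (in_features // 8, out), int32
                            scales: torch.Tensor,   # (out,), float
                            zeros: torch.Tensor):   # (out,), float
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)      # the 8 nibble offsets
    nibbles = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF
    w_int = nibbles.reshape(-1, qweight.shape[1])           # (in, out), values in [0, 15]
    return (w_int.float() - zeros) * scales                 # dequantized fp weights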
12 changes: 6 additions & 6 deletions colossalai/kernel/triton/smooth_attention.py
@@ -13,10 +13,10 @@

if HAS_TRITON:
"""
-this function is modified from
-https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
+these functions are modified from https://github.com/ModelTC/lightllm
"""

+# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
@triton.jit
def _context_flash_attention_kernel(
Q,
@@ -145,20 +145,16 @@ def _context_flash_attention_kernel(
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return

-
-
@torch.no_grad()
def smooth_llama_context_attn_fwd(
q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
):
-
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk, "context process only supports equal query, key, value length"
assert Lk == Lv, "context process only supports equal query, key, value length"
assert Lk in {16, 32, 64, 128}
-BLOCK_N = 128
sm_scale = 1.0 / math.sqrt(Lk)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
@@ -203,6 +199,7 @@ def smooth_llama_context_attn_fwd(
)
return

+# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_1_kernel(
Q,
@@ -264,6 +261,7 @@ def _token_attn_1_kernel(
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
return

+# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_1_alibi_kernel(
Q,
@@ -413,6 +411,7 @@ def token_attn_fwd_1(
)
return

+# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@triton.jit
def _token_attn_softmax_fwd(
softmax_logics,
@@ -479,6 +478,7 @@ def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen,
)
return

+# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_2_kernel(
Prob,
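The attributed kernels run paged, variable-length attention directly on int8 tensors, taking q_input_scale / k_input_scale / v_input_scale in and requantizing by pv_output_scale on the way out. As a correctness oracle only, a dense single-sequence reference in plain PyTorch; it drops the paging and start-loc bookkeeping, and the exact points at which the fused kernels rescale internally may differ:

import math
import torch

@torch.no_grad()
def smooth_attn_reference(q_i8, k_i8, v_i8, q_scale, k_scale, v_scale, out_scale):
    # Dequantize, run vanilla scaled dot-product attention, requantize.
    q = q_i8.float() * q_scale  # (heads, seq_len, head_dim)
    k = k_i8.float() * k_scale
    v = v_i8.float() * v_scale
    attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
    out = (attn @ v) / out_scale
    return out.round().clamp(-128, 127).to(torch.int8)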
3 changes: 0 additions & 3 deletions examples/inference/gptq_llama.py
@@ -8,7 +8,6 @@

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
@@ -50,8 +49,6 @@ def run_llama_test(args):
quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
)

-init_to_get_rotary(model.model.model, base=10000)
-
model_config = model.config
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
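For completeness, the load path this example script relies on, as a minimal standalone snippet; the model path is a placeholder, and inject_fused_attention=False mirrors the call visible in the hunk above:

import torch
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/llama-gptq",            # placeholder for quantized_model_dir
    device=torch.cuda.current_device(),
    inject_fused_attention=False,    # ColossalAI injects its own attention kernels
)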