12 changes: 12 additions & 0 deletions colossalai/inference/quant/smoothquant/models/__init__.py
@@ -0,0 +1,12 @@
try:
import torch_int

HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
    raise ImportError(
        "torch_int is not installed. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
    )

if HAS_TORCH_INT:
from .smoothquant_layer import LLamaSmoothquantAttention
185 changes: 185 additions & 0 deletions colossalai/inference/quant/smoothquant/models/smoothquant_layer.py
@@ -0,0 +1,185 @@
# Code modified from smoothquant: https://github.com/mit-han-lab/smoothquant

from typing import Optional, Tuple

import torch
from torch import nn
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear
from transformers.models.llama.modeling_llama import LlamaAttention

from colossalai.kernel.triton import int8_rotary_embedding_fwd


class LLamaSmoothquantAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads

if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads})."
)

self.attention_weight_scale = 1.0

self.qk_bmm = BMM_S8T_S8N_F32T(1.0)
self.pv_bmm = BMM_S8T_S8N_S8T(1.0)

self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size)
self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size)
self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size)
self.out_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size)

self.q_output_scale = torch.tensor([1.0])
self.k_output_scale = torch.tensor([1.0])
self.rotary_output_scale = torch.tensor([1.0])
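        # Placeholder quantization scales; pack() is expected to overwrite them with the
        # calibrated per-layer values before the module is used.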

def pack(
self,
module: LlamaAttention,
input_scale: float,
q_output_scale: float,
k_output_scale: float,
v_output_scale: float,
out_input_scale: float,
rotary_output_scale: float,
):
        int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads)
int8_module.q_output_scale = q_output_scale
int8_module.k_output_scale = k_output_scale
int8_module.rotary_output_scale = rotary_output_scale
q_output_scale = q_output_scale * module.scaling
module.q_proj.weight *= module.scaling
module.q_proj.bias *= module.scaling
int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, input_scale, q_output_scale)

int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, input_scale, k_output_scale)
int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, input_scale, v_output_scale)
int8_module.out_proj = W8A8BFP32OFP32Linear.from_float(module.out_proj, out_input_scale)
int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale(q_output_scale, k_output_scale)

# alpha = s_prob * s_v / s_out, where s_prob = 1 / 127
int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale(1.0 / 127, v_output_scale, out_input_scale)
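        # Illustrative example (hypothetical calibration values, not from this PR): with
        # v_output_scale = 0.02 and out_input_scale = 0.5, alpha = (1 / 127) * 0.02 / 0.5 ≈ 3.1e-4,
        # i.e. the int8 prob @ V product is rescaled by alpha so it lands on out_proj's input scale.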
return int8_module

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
rotary_emb: Tuple[torch.Tensor],
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, seq_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

cos = rotary_emb[0]
sin = rotary_emb[1]
int8_rotary_embedding_fwd(
query_states.view(-1, self.num_heads, self.head_dim),
cos,
sin,
self.q_output_scale,
self.rotary_output_scale,
)
int8_rotary_embedding_fwd(
key_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.k_output_scale, self.rotary_output_scale
)
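        # The two Triton calls above rotate q/k in place: values are dequantized with the projection
        # output scale, rotated with cos/sin, and re-quantized to int8 with rotary_output_scale.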

if past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(key_states, -1, bsz)
value_states = self._shape(value_states, -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(key_states, -1, bsz)
value_states = self._shape(value_states, -1, bsz)

past_key_value = (key_states, value_states)
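        # key_states / value_states now hold the full (bsz, num_heads, kv_len, head_dim) cache and are
        # returned via past_key_value so the next decoding step can reuse them.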

proj_shape = (bsz * self.num_heads, -1, self.head_dim)

query_states = self._shape(query_states, seq_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)

src_len = key_states.size(1)
attn_weights = self.qk_bmm(query_states, key_states)

if attn_weights.size() != (bsz * self.num_heads, seq_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, seq_len, src_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, seq_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, seq_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, seq_len, src_len) + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attn_weights = attn_weights.view(bsz * self.num_heads, seq_len, src_len)

attn_probs = nn.functional.softmax(attn_weights, dim=-1)

if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_probs = layer_head_mask.view(1, -1, 1, 1) * attn_probs.view(bsz, self.num_heads, seq_len, src_len)
attn_probs = attn_probs.view(bsz * self.num_heads, seq_len, src_len)

if output_attentions:
            # Reshape attn_probs to (bsz, num_heads, seq_len, src_len) so it can be returned when
            # output_attentions is requested, then flatten it back for the batched matmul below.
attn_probs_reshaped = attn_probs.view(bsz, self.num_heads, seq_len, src_len)
attn_probs = attn_probs_reshaped.view(bsz * self.num_heads, seq_len, src_len)
else:
attn_probs_reshaped = None

# (A_row V_row)_row = (A_row V_col ^T)_row
attn_probs.mul_(127).round_()
attn_probs = attn_probs.to(torch.int8)
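        # attn_probs is now int8 with scale 1 / 127, matching the s_prob factor folded into pv_bmm's alpha.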

value_states = value_states.transpose(1, 2).contiguous()
attn_output = self.pv_bmm(attn_probs, value_states)

if attn_output.size() != (bsz * self.num_heads, seq_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, seq_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.view(bsz, self.num_heads, seq_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)

        # Reshape back to (bsz, seq_len, num_heads * head_dim) using the sizes stored on the module rather
        # than `hidden_states`, because `attn_output` can be partitioned across GPUs under tensor parallelism.
attn_output = attn_output.reshape(bsz, seq_len, self.num_heads * self.head_dim).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_probs_reshaped, past_key_value
18 changes: 8 additions & 10 deletions colossalai/kernel/triton/int8_rotary_embedding_kernel.py
@@ -57,17 +57,15 @@ def _rotary_kernel(

cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
-    in_scale = tl.load(input_scale)
-    o_scale = tl.load(output_scale)

-    q0 = q0.to(tl.float32) * in_scale
-    q1 = q1.to(tl.float32) * in_scale
+    q0 = q0.to(tl.float32) * input_scale
+    q1 = q1.to(tl.float32) * input_scale

-    out0 = (q0 * cos - q1 * sin) / o_scale
-    out1 = (q0 * sin + q1 * cos) / o_scale
+    out0 = (q0 * cos - q1 * sin) / output_scale
+    out1 = (q0 * sin + q1 * cos) / output_scale

-    # out0 = out0.to(tl.int8)
-    # out1 = out1.to(tl.int8)
+    out0 = out0.to(tl.int8)
+    out1 = out1.to(tl.int8)

tl.store(
q + off_q0,
@@ -99,8 +97,8 @@ def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):

_rotary_kernel[grid](
q,
-        input_scale,
-        output_scale,
+        input_scale.item(),
+        output_scale.item(),
cos,
sin,
q.stride(0),
105 changes: 105 additions & 0 deletions tests/test_smoothquant/test_llama_attention.py
@@ -0,0 +1,105 @@
import pytest
import torch
from packaging import version

try:
from colossalai.kernel.triton import int8_rotary_embedding_fwd

HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")

try:
from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention

HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")


TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")

import math

import torch
from torch.nn import functional as F


def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
"""
adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
"""
xq = xq.view(bs, seqlen, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
xv = xv.view(bs, seqlen, num_head, head_dim)
mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
mask[mask == 0.0] = -100000000.0
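    # positions above the diagonal get a large negative bias so softmax assigns them ~zero probability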
mask = mask.repeat(bs, num_head, 1, 1)
keys = xk
values = xv
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
sm_scale = 1 / math.sqrt(head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float)

output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
return output


@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT,
reason="triton requires cuda version to be higher than 11.4 or not install torch_int",
)
def test_llama_context_attention():
head_num = 8
seq_len = 32
head_dim = 64
dtype = torch.float
hidden_size = head_num * head_dim

smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num)

smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8)
smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8)
smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8)
smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8)

smooth_attn = smooth_attn.to("cuda")

input = torch.randint(-127, 127, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda")

q = smooth_attn.q_proj(input)
k = smooth_attn.k_proj(input)
v = smooth_attn.v_proj(input)

cos_shape = (seq_len, head_dim // 2)
cos = torch.ones(cos_shape, dtype=dtype, device="cuda")
sin = torch.zeros(cos_shape, dtype=dtype, device="cuda")

in_scale = torch.tensor([1.0], device="cuda")
out_scale = torch.tensor([1.0], device="cuda")
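    # With cos = 1, sin = 0 and unit scales, the rotary embedding is effectively an identity op here,
    # so q/k keep their int8 values and the quantized path can be checked against the float reference.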

int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale, out_scale)
int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale, out_scale)

q = q.to(torch.float)
k = k.to(torch.float)
v = v.to(torch.float)
torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim)
torch_out = (torch_out).to(torch.int8).view(-1, seq_len, head_num * head_dim)
torch_out = smooth_attn.out_proj(torch_out)
smooth_out, _, _ = smooth_attn(input, (cos, sin))
smooth_out = smooth_out.to(torch.float)
torch_out = torch_out.to(torch.float)

assert torch.allclose(
smooth_out.cpu(), torch_out.cpu(), rtol=1e-2, atol=1e-2
), "outputs from triton and torch are not matched"


if __name__ == "__main__":
test_llama_context_attention()