-
Couldn't load subscription status.
- Fork 4.5k
[inference] add smoothquant llama attention #4850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Xu-Kai
merged 5 commits into
hpcaitech:feature/smoothquant
from
Xu-Kai:feature/smoothquant
Oct 3, 2023
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| try: | ||
| import torch_int | ||
|
|
||
| HAS_TORCH_INT = True | ||
| except ImportError: | ||
| HAS_TORCH_INT = False | ||
| raise ImportError( | ||
| "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int" | ||
| ) | ||
|
|
||
| if HAS_TORCH_INT: | ||
| from .smoothquant_layer import LLamaSmoothquantAttention |
185 changes: 185 additions & 0 deletions
185
colossalai/inference/quant/smoothquant/models/smoothquant_layer.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| # Code modified from smoothquant: https://github.com/mit-han-lab/smoothquant | ||
|
|
||
| from typing import Optional, Tuple | ||
|
|
||
| import torch | ||
| from torch import nn | ||
| from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T | ||
| from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear | ||
| from transformers.models.llama.modeling_llama import LlamaAttention | ||
|
|
||
| from colossalai.kernel.triton import int8_rotary_embedding_fwd | ||
|
|
||
|
|
||
| class LLamaSmoothquantAttention(nn.Module): | ||
| def __init__( | ||
| self, | ||
| hidden_size: int, | ||
| num_heads: int, | ||
| ): | ||
| super().__init__() | ||
| self.hidden_size = hidden_size | ||
| self.num_heads = num_heads | ||
| self.head_dim = hidden_size // num_heads | ||
|
|
||
| if (self.head_dim * num_heads) != self.hidden_size: | ||
| raise ValueError( | ||
| f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" | ||
| f" and `num_heads`: {num_heads})." | ||
| ) | ||
|
|
||
| self.attention_weight_scale = 1.0 | ||
|
|
||
| self.qk_bmm = BMM_S8T_S8N_F32T(1.0) | ||
| self.pv_bmm = BMM_S8T_S8N_S8T(1.0) | ||
|
|
||
| self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) | ||
| self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) | ||
| self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) | ||
| self.out_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) | ||
|
|
||
| self.q_output_scale = torch.tensor([1.0]) | ||
| self.k_output_scale = torch.tensor([1.0]) | ||
| self.rotary_output_scale = torch.tensor([1.0]) | ||
|
|
||
| def pack( | ||
yuanheng-zhao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self, | ||
| module: LlamaAttention, | ||
| input_scale: float, | ||
| q_output_scale: float, | ||
| k_output_scale: float, | ||
| v_output_scale: float, | ||
| out_input_scale: float, | ||
| rotary_output_scale: float, | ||
| ): | ||
| int8_module = LLamaSmoothquantAttention(module.hidden_size, module.head_dim) | ||
| int8_module.q_output_scale = q_output_scale | ||
| int8_module.k_output_scale = k_output_scale | ||
| int8_module.rotary_output_scale = rotary_output_scale | ||
| q_output_scale = q_output_scale * module.scaling | ||
| module.q_proj.weight *= module.scaling | ||
| module.q_proj.bias *= module.scaling | ||
| int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, input_scale, q_output_scale) | ||
|
|
||
| int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, input_scale, k_output_scale) | ||
| int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, input_scale, v_output_scale) | ||
| int8_module.out_proj = W8A8BFP32OFP32Linear.from_float(module.out_proj, out_input_scale) | ||
| int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale(q_output_scale, k_output_scale) | ||
|
|
||
| # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 | ||
| int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale(1.0 / 127, v_output_scale, out_input_scale) | ||
| return int8_module | ||
|
|
||
| def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): | ||
| return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() | ||
|
|
||
| @torch.no_grad() | ||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| rotary_emb: Tuple[torch.Tensor], | ||
| key_value_states: Optional[torch.Tensor] = None, | ||
| past_key_value: Optional[Tuple[torch.Tensor]] = None, | ||
| attention_mask: Optional[torch.Tensor] = None, | ||
| layer_head_mask: Optional[torch.Tensor] = None, | ||
| output_attentions: bool = False, | ||
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||
| bsz, seq_len, _ = hidden_states.size() | ||
| # get query proj | ||
| query_states = self.q_proj(hidden_states) | ||
| key_states = self.k_proj(hidden_states) | ||
| value_states = self.v_proj(hidden_states) | ||
|
|
||
| cos = rotary_emb[0] | ||
| sin = rotary_emb[1] | ||
| int8_rotary_embedding_fwd( | ||
| query_states.view(-1, self.num_heads, self.head_dim), | ||
| cos, | ||
| sin, | ||
| self.q_output_scale, | ||
| self.rotary_output_scale, | ||
| ) | ||
| int8_rotary_embedding_fwd( | ||
| key_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.k_output_scale, self.rotary_output_scale | ||
| ) | ||
|
|
||
| if past_key_value is not None: | ||
| # reuse k, v, self_attention | ||
| key_states = self._shape(key_states, -1, bsz) | ||
| value_states = self._shape(value_states, -1, bsz) | ||
| key_states = torch.cat([past_key_value[0], key_states], dim=2) | ||
| value_states = torch.cat([past_key_value[1], value_states], dim=2) | ||
| else: | ||
| # self_attention | ||
| key_states = self._shape(key_states, -1, bsz) | ||
| value_states = self._shape(value_states, -1, bsz) | ||
|
|
||
| past_key_value = (key_states, value_states) | ||
|
|
||
| proj_shape = (bsz * self.num_heads, -1, self.head_dim) | ||
|
|
||
| query_states = self._shape(query_states, seq_len, bsz).view(*proj_shape) | ||
| key_states = key_states.view(*proj_shape) | ||
| value_states = value_states.view(*proj_shape) | ||
|
|
||
| src_len = key_states.size(1) | ||
| attn_weights = self.qk_bmm(query_states, key_states) | ||
|
|
||
| if attn_weights.size() != (bsz * self.num_heads, seq_len, src_len): | ||
| raise ValueError( | ||
| f"Attention weights should be of size {(bsz * self.num_heads, seq_len, src_len)}, but is" | ||
| f" {attn_weights.size()}" | ||
| ) | ||
|
|
||
| if attention_mask is not None: | ||
| if attention_mask.size() != (bsz, 1, seq_len, src_len): | ||
| raise ValueError( | ||
| f"Attention mask should be of size {(bsz, 1, seq_len, src_len)}, but is {attention_mask.size()}" | ||
| ) | ||
| attn_weights = attn_weights.view(bsz, self.num_heads, seq_len, src_len) + attention_mask | ||
| attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) | ||
| attn_weights = attn_weights.view(bsz * self.num_heads, seq_len, src_len) | ||
|
|
||
| attn_probs = nn.functional.softmax(attn_weights, dim=-1) | ||
|
|
||
| if layer_head_mask is not None: | ||
| if layer_head_mask.size() != (self.num_heads,): | ||
| raise ValueError( | ||
| f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" | ||
| f" {layer_head_mask.size()}" | ||
| ) | ||
| attn_probs = layer_head_mask.view(1, -1, 1, 1) * attn_probs.view(bsz, self.num_heads, seq_len, src_len) | ||
| attn_probs = attn_probs.view(bsz * self.num_heads, seq_len, src_len) | ||
|
|
||
| if output_attentions: | ||
| # this operation is a bit awkward, but it's required to | ||
| # make sure that attn_weights keeps its gradient. | ||
| # In order to do so, attn_weights have to be reshaped | ||
| # twice and have to be reused in the following | ||
| attn_probs_reshaped = attn_probs.view(bsz, self.num_heads, seq_len, src_len) | ||
| attn_probs = attn_probs_reshaped.view(bsz * self.num_heads, seq_len, src_len) | ||
| else: | ||
| attn_probs_reshaped = None | ||
|
|
||
| # (A_row V_row)_row = (A_row V_col ^T)_row | ||
| attn_probs.mul_(127).round_() | ||
| attn_probs = attn_probs.to(torch.int8) | ||
|
|
||
| value_states = value_states.transpose(1, 2).contiguous() | ||
| attn_output = self.pv_bmm(attn_probs, value_states) | ||
|
|
||
| if attn_output.size() != (bsz * self.num_heads, seq_len, self.head_dim): | ||
| raise ValueError( | ||
| f"`attn_output` should be of size {(bsz, self.num_heads, seq_len, self.head_dim)}, but is" | ||
| f" {attn_output.size()}" | ||
| ) | ||
|
|
||
| attn_output = attn_output.view(bsz, self.num_heads, seq_len, self.head_dim) | ||
| attn_output = attn_output.transpose(1, 2) | ||
|
|
||
| # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be | ||
| # partitioned aross GPUs when using tensor-parallelism. | ||
| attn_output = attn_output.reshape(bsz, seq_len, self.num_heads * self.head_dim).contiguous() | ||
| attn_output = self.out_proj(attn_output) | ||
|
|
||
| return attn_output, attn_probs_reshaped, past_key_value | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| import pytest | ||
| import torch | ||
| from packaging import version | ||
|
|
||
| try: | ||
| from colossalai.kernel.triton import int8_rotary_embedding_fwd | ||
|
|
||
| HAS_TRITON = True | ||
| except ImportError: | ||
| HAS_TRITON = False | ||
| print("please install triton from https://github.com/openai/triton") | ||
|
|
||
| try: | ||
| from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention | ||
|
|
||
| HAS_TORCH_INT = True | ||
| except ImportError: | ||
| HAS_TORCH_INT = False | ||
| print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") | ||
|
|
||
|
|
||
| TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") | ||
Xu-Kai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| import math | ||
|
|
||
| import torch | ||
| from torch.nn import functional as F | ||
|
|
||
|
|
||
| def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): | ||
| """ | ||
| adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 | ||
| """ | ||
| xq = xq.view(bs, seqlen, num_head, head_dim) | ||
| xk = xk.view(bs, seqlen, num_head, head_dim) | ||
| xv = xv.view(bs, seqlen, num_head, head_dim) | ||
| mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() | ||
| mask[mask == 0.0] = -100000000.0 | ||
| mask = mask.repeat(bs, num_head, 1, 1) | ||
| keys = xk | ||
| values = xv | ||
| xq = xq.transpose(1, 2) | ||
| keys = keys.transpose(1, 2) | ||
| values = values.transpose(1, 2) | ||
| sm_scale = 1 / math.sqrt(head_dim) | ||
| scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale | ||
| scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float) | ||
|
|
||
| output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) | ||
| return output | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT, | ||
| reason="triton requires cuda version to be higher than 11.4 or not install torch_int", | ||
| ) | ||
| def test_llama_context_attention(): | ||
| head_num = 8 | ||
| seq_len = 32 | ||
| head_dim = 64 | ||
| dtype = torch.float | ||
| hidden_size = head_num * head_dim | ||
|
|
||
| smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num) | ||
|
|
||
| smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) | ||
| smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) | ||
| smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) | ||
| smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) | ||
|
|
||
| smooth_attn = smooth_attn.to("cuda") | ||
|
|
||
| input = torch.randint(-127, 127, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda") | ||
|
|
||
| q = smooth_attn.q_proj(input) | ||
| k = smooth_attn.k_proj(input) | ||
| v = smooth_attn.v_proj(input) | ||
|
|
||
| cos_shape = (seq_len, head_dim // 2) | ||
| cos = torch.ones(cos_shape, dtype=dtype, device="cuda") | ||
| sin = torch.zeros(cos_shape, dtype=dtype, device="cuda") | ||
|
|
||
| in_scale = torch.tensor([1.0], device="cuda") | ||
| out_scale = torch.tensor([1.0], device="cuda") | ||
|
|
||
| int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale, out_scale) | ||
| int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale, out_scale) | ||
|
|
||
| q = q.to(torch.float) | ||
| k = k.to(torch.float) | ||
| v = v.to(torch.float) | ||
| torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim) | ||
| torch_out = (torch_out).to(torch.int8).view(-1, seq_len, head_num * head_dim) | ||
| torch_out = smooth_attn.out_proj(torch_out) | ||
| smooth_out, _, _ = smooth_attn(input, (cos, sin)) | ||
| smooth_out = smooth_out.to(torch.float) | ||
| torch_out = torch_out.to(torch.float) | ||
|
|
||
| assert torch.allclose( | ||
| smooth_out.cpu(), torch_out.cpu(), rtol=1e-2, atol=1e-2 | ||
| ), "outputs from triton and torch are not matched" | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_llama_context_attention() | ||
File renamed without changes.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.