15 changes: 13 additions & 2 deletions colossalai/inference/README.md
@@ -34,11 +34,13 @@ In this section we discuss how the colossal inference works and integrates with
- [x] policy
- [x] context forward
- [x] token forward
- [ ] support flash-decoding
- [ ] Replace the kernels with `faster-transformer` in token-forward stage
- [ ] Support all models
- [x] Llama
- [x] Llama-2
- [x] Bloom
- [ ] Chatglm2
- [x] Chatglm2
- [ ] Benchmarking for all models

## Get started
@@ -68,6 +70,12 @@ git clone https://github.com/ModelTC/lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
cd lightllm
pip3 install -e .

# also, install xformers from source:
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers

```
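
If any of these optional dependencies fail to build, inference still runs but falls back to slower paths. Below is a small check script (illustrative only, not part of the repo) that mirrors the import guards used in the inference code and reports which fast-path packages are importable:

```python
# check_kernels.py -- hypothetical helper, not shipped with Colossal-AI
def check_optional_kernels():
    """Report which optional fast-path kernel packages import correctly."""
    for name in ("lightllm", "xformers", "flash_attn"):
        try:
            __import__(name)
            print(f"{name}: OK")
        except ImportError as err:
            print(f"{name}: missing ({err})")


if __name__ == "__main__":
    check_optional_kernels()
```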

### Docker
@@ -89,7 +97,10 @@ git checkout 28c1267cfca536b7b4f28e921e03de735b003039
cd lightllm
pip3 install -e .


# install xformers from source
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
```
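
When building xformers inside a container that will run on different GPUs than the build machine, `TORCH_CUDA_ARCH_LIST` has to name the target architectures. A sketch (assuming PyTorch is already installed; the fallback value is an example) that prints a suitable value for the GPUs visible on the current machine:

```python
# print_arch_list.py -- illustrative helper for choosing TORCH_CUDA_ARCH_LIST
import torch

if torch.cuda.is_available():
    caps = {torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count())}
    arch_list = ";".join(f"{major}.{minor}" for major, minor in sorted(caps))
    print(f"export TORCH_CUDA_ARCH_LIST={arch_list}")
else:
    # No GPU visible at build time: set the list by hand, e.g. "8.0" for A100.
    print("export TORCH_CUDA_ARCH_LIST=8.0")
```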

### Dive into fast-inference!
54 changes: 37 additions & 17 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -1,4 +1,5 @@
from typing import List, Optional, Tuple
import math

import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -35,6 +36,13 @@
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
HAS_LIGHTLLM_KERNEL = False

try:
    from flash_attn import flash_attn_with_kvcache

    HAS_FLASH_KERNEL = True
except ImportError:
    HAS_FLASH_KERNEL = False
    print("please install flash attention from https://github.com/Dao-AILab/flash-attention")


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -209,7 +217,8 @@ def llama_model_forward(
hidden_states=all_hidden_states,
attentions=all_self_attns,
)



@staticmethod
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
@@ -253,6 +262,7 @@ def llama_decoder_layer_forward(
outputs += (present_key_value,)

return outputs


@staticmethod
def llama_flash_attn_kvcache_forward(
@@ -348,24 +358,34 @@ def llama_flash_attn_kvcache_forward(
infer_state.decode_mem_index,
infer_state.cache_manager,
)

# second token and follows
# kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim)
if self.num_key_value_groups == 1:
    if HAS_FLASH_KERNEL:
        attn_output = torch.empty_like(query_states)
        heads_per_group = self.num_heads // self.num_key_value_heads
        cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
        cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]

        query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
        cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
        cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)

        attn_output = flash_attn_with_kvcache(
            q=query_states,
            k_cache=cache_k,
            v_cache=cache_v,
            softmax_scale=1 / math.sqrt(self.head_dim),
            causal=True,
        )
    else:
        attn_output = torch.empty_like(query_states)
        token_attention_fwd(
            query_states,
            infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
            infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
            attn_output,
            infer_state.block_loc,
            infer_state.start_loc,
            infer_state.seq_len,
            infer_state.cache_manager.past_key_values_length,
        )
else:
    attn_output = torch.empty_like(query_states)
    Llama2TokenAttentionForwards.token_attn(
        query_states,
        infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
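
As a readability aid, here is what the fused decode-attention kernels above compute, written as plain PyTorch (a naive reference for a single generated token; names and shapes are illustrative, it assumes every cached position is valid, and it is far slower than token_attention_fwd or flash-decoding):

```python
import math

import torch


def naive_decode_attention(q, k_cache, v_cache):
    """Single-token decode attention over a KV cache (reference only).

    q:       (batch, 1, num_heads, head_dim)
    k_cache: (batch, seq_len, num_kv_heads, head_dim)
    v_cache: (batch, seq_len, num_kv_heads, head_dim)
    """
    num_heads, head_dim = q.shape[2], q.shape[3]
    num_kv_heads = k_cache.shape[2]
    groups = num_heads // num_kv_heads  # corresponds to num_key_value_groups

    # Expand KV heads so each query head has a matching KV head (handles GQA/MQA).
    k = k_cache.repeat_interleave(groups, dim=2)
    v = v_cache.repeat_interleave(groups, dim=2)

    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(head_dim)
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v)  # (batch, 1, num_heads, head_dim)
```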
4 changes: 2 additions & 2 deletions examples/inference/bench_llama.py
@@ -121,8 +121,8 @@ def test_llama(args):
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
parser.add_argument("--input_len", type=int, default=256, help="Maximum input length")
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Maximum batch size")
parser.add_argument("--input_len", type=int, default=128, help="Maximum input length")
parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
parser.add_argument(
"--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]