Jais #2591 (Closed)

36 changes: 36 additions & 0 deletions README.md
@@ -16,6 +16,42 @@ Easy, fast, and cheap LLM serving for everyone

---

<h2 align="left">
This is a fork of the vLLM repo to add support for Jais.
</h2>

## How to use
1. Clone this repo: [vllm-jais](https://github.com/SamujjwalSam/vllm-jais)
2. Copy the file [jais.py](vllm/model_executor/models/jais.py) into the [models](vllm/model_executor/models/) directory of your vLLM installation (a small automation sketch follows this list)
3. Update [__init__.py](vllm/model_executor/models/__init__.py) and `model_loader.py` to register `JAISLMHeadModel` (the exact lines appear in the diffs below)
4. Download the [Jais model](https://huggingface.co/core42/jais-30b-chat-v1) from Hugging Face
5. Update the `config.json` in the model directory, if required (see the note below)
6. Run the [main_jais.py](main_jais.py) file
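
Step 2 can be scripted. Below is a minimal sketch, assuming the fork was cloned into a `vllm-jais` directory and that vLLM is already installed in the active environment (both paths are assumptions):

```python
# Minimal sketch for step 2 (paths are assumptions; adjust to your setup).
# It copies jais.py from this fork into the installed vLLM package. The
# one-line registrations in __init__.py and model_loader.py (step 3) still
# have to be applied by hand, as shown in the diffs below.
import shutil
from pathlib import Path

import vllm.model_executor.models as models_pkg

fork_dir = Path("vllm-jais")                 # where this repo was cloned
dest_dir = Path(models_pkg.__file__).parent  # installed vllm/model_executor/models/
shutil.copy(fork_dir / "vllm" / "model_executor" / "models" / "jais.py", dest_dir / "jais.py")
print(f"Copied jais.py into {dest_dir}")
```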


## Limitations
- Tested only with the `13b` and `30b` models
- Works only with the vLLM `0.2.1.post1` tag
- `13b` can only be used on a single GPU due to non-divisibility of the FF layer dim
- `30b` can only be used on either one or two GPUs due to non-divisibility of the FF layer dim (see the arithmetic check after this list)
- The `config.json` file needs to be modified to add extra attributes
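
These limits follow from the feed-forward inner dimension in the bundled configs (`n_inner` is 13653 for 13B and 19114 for 30B), which tensor parallelism must split evenly across GPUs. A quick arithmetic check consistent with the points above:

```python
# Tensor parallelism shards the feed-forward (inner) dimension across GPUs,
# so n_inner must be divisible by tensor_parallel_size. The values below are
# taken from configs/config_13B.json and configs/config_30B.json in this fork.
N_INNER = {"jais-13b": 13653, "jais-30b": 19114}

for name, n_inner in N_INNER.items():
    usable = [tp for tp in (1, 2, 4, 8) if n_inner % tp == 0]
    print(f"{name}: n_inner={n_inner}, usable tensor_parallel_size -> {usable}")

# jais-13b: n_inner=13653, usable tensor_parallel_size -> [1]
# jais-30b: n_inner=19114, usable tensor_parallel_size -> [1, 2]
```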


**NOTE:** You might need to modify the `config.json` file after downloading the model from [Hugging Face](https://huggingface.co/core42/jais-30b-chat-v1); the file is located wherever the model weights were downloaded.
Config files for `Jais-13B` and `Jais-30B` are included in this fork at `vllm/configs/config_13B.json` and `vllm/configs/config_30B.json`, respectively.
Replace the contents of the downloaded `config.json` with the corresponding copy.

For example, the `architectures` attribute might be missing from the downloaded `config.json`:
```json
"architectures": [
    "JAISLMHeadModel"
],
```
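
A minimal sketch of the replacement step, assuming the fork sits in `vllm-jais` and the 30B weights were downloaded to a local `jais-30b-chat-v1` directory (both paths are placeholders):

```python
# Replace the downloaded config.json with the bundled one and sanity-check
# that the attributes vLLM needs are present. Paths are placeholders.
import json
import shutil
from pathlib import Path

bundled = Path("vllm-jais") / "vllm" / "configs" / "config_30B.json"
model_dir = Path("jais-30b-chat-v1")  # wherever the HF weights were downloaded

shutil.copy(model_dir / "config.json", model_dir / "config.json.bak")  # keep a backup
shutil.copy(bundled, model_dir / "config.json")

cfg = json.loads((model_dir / "config.json").read_text())
assert cfg.get("architectures") == ["JAISLMHeadModel"]
assert cfg.get("model_type") == "jais"
```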

"""

---

*Latest News* 🔥
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
66 changes: 66 additions & 0 deletions main_jais.py
@@ -0,0 +1,66 @@
#!/usr/bin/python3.8
# coding=utf-8
"""
__synopsis__ : Added support for Jais into vLLM.
__description__ : This script contains code to run Jais model using vLLM.
__project__ :
__author__ : Samujjwal Ghosh <[email protected]>, Samta Kamboj <[email protected]>
__version__ : "0.1"
__date__ : "25 Jan, 2024"
"""

from vllm import LLM, SamplingParams

def load_model_vllm(model_path, dtype="float16", tensor_parallel_size=1):
    print(f'Loading model from path: [{model_path}]')
    llm = LLM(model=model_path,
              trust_remote_code=True,  # Jais ships custom config/modeling code on the HF Hub
              dtype=dtype,
              tensor_parallel_size=tensor_parallel_size,
              # enforce_eager=True,
              # gpu_memory_utilization=0.95,  # fraction of GPU memory vLLM may reserve
              swap_space=16,  # CPU swap space per GPU, in GiB
              # block_size=32,  # KV-cache block size, in tokens
              )

    return llm

def main(n_gpus=1, model_path='core42/jais-30b-chat-v1'):
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The future of AI is",
    ]

    # Load the model.
    llm = load_model_vllm(model_path, tensor_parallel_size=n_gpus)

    # Set the sampling params for generation.
    sampling_params = SamplingParams(
        n=1,
        temperature=0.7,
        top_p=0.7,
        max_tokens=200,
        frequency_penalty=0.2,
        presence_penalty=0.2,
    )

    # Generate the outputs.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, \t\t Generated text: {generated_text!r}")


if __name__ == "__main__":
"""
NOTE:
1. Tested only with Jais `13b` and `30b` models
2. Works only with vLLM "0.2.1-post1" tag
3. `13b` can only be used on a single GPU due to non-divisibility of FF layer dim
4. `30b` can only be used either on a single GPU or two GPUs due to non-divisibility of FF layer dim
5. Need to modify the config.json file to add extra attributes
"""
main(1, model_path = 'core42/jais-30b-chat-v1')
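    # For `30b` on two GPUs (the largest tensor-parallel size its FF dim
    # allows), one could instead call, e.g.:
    #     main(2, model_path='core42/jais-30b-chat-v1')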
2 changes: 1 addition & 1 deletion vllm/__init__.py
@@ -8,7 +8,7 @@
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

__version__ = "0.2.1"
__version__ = "0.2.1.post1"

__all__ = [
"LLM",
40 changes: 40 additions & 0 deletions vllm/configs/config_13B.json
@@ -0,0 +1,40 @@
{
"_name_or_path": "inception-mbzuai/jais-13b-chat",
"activation_function": "swiglu",
"architectures": [
"JAISLMHeadModel"
],
"attn_pdrop": 0.0,
"auto_map": {
"AutoConfig": "configuration_jais.JAISConfig",
"AutoModel": "modeling_jais.JAISModel",
"AutoModelForCausalLM": "modeling_jais.JAISLMHeadModel",
"AutoModelForQuestionAnswering": "modeling_jais.JAISForQuestionAnswering",
"AutoModelForSequenceClassification": "modeling_jais.JAISForSequenceClassification",
"AutoModelForTokenClassification": "modeling_jais.JAISForTokenClassification"
},
"bos_token_id": 0,
"embd_pdrop": 0.0,
"embeddings_scale": 14.6,
"eos_token_id": 0,
"initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
"model_type": "jais",
"n_embd": 5120,
"n_head": 40,
"n_inner": 13653,
"n_layer": 40,
"n_positions": 2048,
"pad_token_id": 0,
"position_embedding_type": "alibi",
"reorder_and_upcast_attn": false,
"resid_pdrop": 0.0,
"scale_attn_by_inverse_layer_idx": false,
"scale_attn_weights": true,
"scale_qk_dot_by_d": true,
"torch_dtype": "float32",
"transformers_version": "4.35.0.dev0",
"use_cache": true,
"vocab_size": 84992,
"width_scale": 0.11100000000000002
}
38 changes: 38 additions & 0 deletions vllm/configs/config_30B.json
@@ -0,0 +1,38 @@
{
"attn_pdrop": 0.0,
"scale_attn_weights": true,
"resid_pdrop": 0.0,
"mup_embeddings_scale": 1.0,
"n_inner": 19114,
"n_embd": 7168,
"layer_norm_epsilon": 1e-05,
"n_positions": 8192,
"activation_function": "swiglu",
"n_head": 56,
"n_layer": 48,
"mup_output_alpha": 2.22,
"mup_width_scale": 0.03571428571428571,
"position_embedding_type": "alibi",
"mup_scale_qk_dot_by_d": true,
"tie_word_embeddings": true,
"vocab_size": 84992,
"embd_pdrop": 0.0,
"model_type": "jais",
"use_cache": true,
"width_scale":0.07928571428571429,
"embeddings_scale":14.6,
"scale_qk_dot_by_d":false,
"auto_map": {
"AutoConfig": "configuration_jais.JAISConfig",
"AutoModel": "modeling_jais.JAISModel",
"AutoModelForCausalLM": "modeling_jais.JAISLMHeadModel",
"AutoModelForQuestionAnswering": "modeling_jais.JAISForQuestionAnswering",
"AutoModelForSequenceClassification": "modeling_jais.JAISForSequenceClassification",
"AutoModelForTokenClassification": "modeling_jais.JAISForTokenClassification"
},
"eos_token_id": 0,
"pad_token_id": 0,
"bos_token_id": 0,
"initializer_range": 0.02,
"architectures": ["JAISLMHeadModel"]
}
1 change: 1 addition & 0 deletions vllm/model_executor/model_loader.py
@@ -31,6 +31,7 @@
"OPTForCausalLM": OPTForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,
"RWForCausalLM": FalconForCausalLM,
"JAISLMHeadModel": JAISLMHeadModel,
}

# FIXME(woosuk): Remove this once all models support quantization.
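
For reference, the `architectures` entry in the model's `config.json` is what gets matched against this registry; a rough sketch of the lookup (illustrative, not the exact vLLM 0.2.1 code) looks like this:

```python
# Illustrative sketch: vLLM resolves the model class by matching the
# "architectures" strings from the HF config against _MODEL_REGISTRY, which
# is why config.json must name "JAISLMHeadModel" for this fork's jais.py
# to be picked up.
def _get_model_architecture(config):
    architectures = getattr(config, "architectures", []) or []
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
```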
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -13,6 +13,7 @@
from vllm.model_executor.models.mpt import MPTForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel
from vllm.model_executor.models.jais import JAISLMHeadModel

__all__ = [
"AquilaForCausalLM",
@@ -30,4 +31,5 @@
"OPTForCausalLM",
"QWenLMHeadModel",
"MistralForCausalLM",
"JAISLMHeadModel",
]