Need help with supporting "core42/jais-13b-chat" model #1808

@sam-iink

Hello Team,

I am trying to add support for the "core42/jais-13b-chat" model to vLLM. I have completed most of the required changes, except for the ALiBi position embeddings.

This is what the model looks like when loaded with HF:

JAISLMHeadModel(
  (transformer): JAISModel(
    (wte): Embedding(84992, 5120)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-39): 40 x JAISBlock(
        (ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (attn): JAISAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (mlp): JAISMLP(
          (c_fc): Conv1D()
          (c_fc2): Conv1D()
          (c_proj): Conv1D()
          (act): SwiGLUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
    (relative_pe): AlibiPositionEmbeddingLayer()
  )
  (lm_head): Linear(in_features=5120, out_features=84992, bias=False)
)

and this is how it looks after I made the necessary changes for vLLM 0.2.1-post1:

JAISLMHeadModel(                                                                        
  (transformer): JAISModel(
    (wte): VocabParallelEmbedding()                                         
    (h): ModuleList(                                                                    
      (0-39): 40 x JAISBlock(                                                           
        (ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)                  
        (attn): JAISAttention(                                              
          (c_attn): ColumnParallelLinear()                                  
          (c_proj): RowParallelLinear()                                                 
          (attn): PagedAttentionWithALiBi()
        )                                                                               
        (ln_2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)           
        (mlp): JAISMLP(                                                                 
          (c_fc): ColumnParallelLinear()                                                
          (c_fc2): ColumnParallelLinear()
          (c_proj): RowParallelLinear()
          (act): SiluAndMul()                                               
        )                                                                               
      )                                                                                 
    )                                                                                   
    (ln_f): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
  )                                                                                     
  (sampler): Sampler()                                                                  
)

However, checkpoint loading fails because of the (relative_pe): AlibiPositionEmbeddingLayer() module, which exists in the HF model but has no counterpart in my vLLM class.
I don't know how to make the corresponding changes in the class definition.

Can someone please help me with this?
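
To make the loading problem concrete, here is a minimal sketch of one option I am considering: skipping the `relative_pe` entries while iterating over the checkpoint, on the assumption that the ALiBi slopes can be recomputed from the head count instead of being loaded. The helper name `filter_checkpoint_weights` and the key `transformer.relative_pe.slopes` are my own guesses for illustration, not something taken from vLLM or verified against the checkpoint.

```python
from typing import Any, Iterable, Iterator, Tuple


def filter_checkpoint_weights(
    named_weights: Iterable[Tuple[str, Any]],
    skip_substrings: Tuple[str, ...] = ("relative_pe",),
) -> Iterator[Tuple[str, Any]]:
    """Yield (name, tensor) pairs, dropping entries the vLLM class has no slot for.

    Hypothetical helper: the idea is that ALiBi parameters stored under
    `relative_pe` are not loaded, because the slopes would be rebuilt
    from the number of attention heads instead.
    """
    for name, tensor in named_weights:
        if any(marker in name for marker in skip_substrings):
            continue  # no matching parameter on the vLLM side
        yield name, tensor


# Toy stand-in for a checkpoint iterator; real values would be tensors.
example = [
    ("transformer.wte.weight", "W0"),
    ("transformer.relative_pe.slopes", "S"),  # assumed key name
    ("lm_head.weight", "W1"),
]
kept_names = [name for name, _ in filter_checkpoint_weights(example)]
```

With the toy list above, `kept_names` contains only the two non-ALiBi entries. Whether dropping these entries is safe depends on the stored values matching what gets recomputed, which I have not confirmed.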

Jais class definition: https://huggingface.co/core42/jais-13b-chat/blob/main/modeling_jais.py
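
For reference, this is a sketch of the standard ALiBi slope recipe from the original ALiBi work, written in plain Python. I am assuming that the Jais `AlibiPositionEmbeddingLayer` follows the same recipe and that the model has 40 attention heads; both should be checked against `modeling_jais.py` and the model config. The function name `get_alibi_slopes` is my own.

```python
import math
from typing import List


def get_alibi_slopes(num_heads: int) -> List[float]:
    """Per-head ALiBi slopes using the standard geometric recipe."""
    # Largest power of two that does not exceed num_heads.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    # For n = closest_power_of_2 the slopes are base, base^2, ..., base^n
    # with base = 2^(-8/n).
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = [base ** (i + 1) for i in range(closest_power_of_2)]

    if closest_power_of_2 != num_heads:
        # Remaining heads take the odd powers of the base that the next
        # power of two would use.
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = min(closest_power_of_2, num_heads - closest_power_of_2)
        slopes += [extra_base ** (2 * i + 1) for i in range(num_remaining)]
    return slopes


slopes_8 = get_alibi_slopes(8)    # power-of-two case
slopes_40 = get_alibi_slopes(40)  # assumed head count for this model
```

If the attention layer accepts per-head slopes, a list like this could be converted to a tensor and passed in. With tensor parallelism, each rank would presumably need only the slice of slopes for its own heads.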
