diff --git a/README.md b/README.md index f57c3f7862ed..9d3f742225ea 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) - InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) +- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 4019e0bbd90f..af4eb81646eb 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -66,7 +66,11 @@ Alongside each architecture, we include some popular models that use it. * - :code:`InternLM2ForCausalLM` - InternLM2 - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc. - - + - + * - :code:`JAISLMHeadModel` + - Jais + - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. + - * - :code:`LlamaForCausalLM` - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index bc3b6a582d53..069830c4d7cb 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -27,6 +27,7 @@ "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 263727cac19f..e75dda750cb2 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -242,8 +242,7 @@ def sample( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, logits, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py new file mode 100644 index 000000000000..74c8e7f96302 --- /dev/null +++ b/vllm/model_executor/models/jais.py @@ -0,0 +1,351 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py +# Copyright 2023 The vLLM team. +# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights +# reserved. +# Copyright 2023 Cerebras Systems. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Jais model compatible with HuggingFace weights.""" + +import math +from typing import List, Optional, Tuple + +import torch +from torch import nn +from vllm.transformers_utils.configs import JAISConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, +) +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) +from vllm.sequence import SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class SwiGLUActivation(nn.Module): + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return x1 * nn.functional.silu(x2) + + +def _get_alibi_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + +class JAISAttention(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + if hasattr(config, "scale_qk_dot_by_d"): + config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d + self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5 + self.scale = self.head_dim**-self.attn_scale_power + + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + bias=True, + linear_method=linear_method, + ) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + linear_method=linear_method, + ) + + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end] + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + ) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + key_cache, value_cache = kv_cache + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class JAISMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.swiglu = config.activation_function == "swiglu" + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) + self.c_fc2 = (ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) if self.swiglu else None) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + linear_method=linear_method, + ) + + self.act = SwiGLUActivation() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.swiglu: + hidden_states2, _ = self.c_fc2(hidden_states) + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = (self.act(hidden_states, hidden_states2) + if self.swiglu else self.act(hidden_states)) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class JAISBlock(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = JAISAttention(config, linear_method) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = JAISMLP(inner_dim, config, linear_method) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class JAISModel(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + assert not config.add_cross_attention + assert not config.scale_attn_by_inverse_layer_idx + assert not config.reorder_and_upcast_attn + self.embed_dim = config.hidden_size + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.wpe = (nn.Embedding(config.max_position_embeddings, + self.embed_dim) + if config.position_embedding_type != "alibi" else None) + if hasattr(config, "embeddings_scale"): + self.embeddings_scale = config.embeddings_scale + else: + self.embeddings_scale = config.mup_embeddings_scale + self.h = nn.ModuleList([ + JAISBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + inputs_embeds = self.wte(input_ids) + if self.wpe is not None: + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + hidden_states = inputs_embeds + hidden_states *= torch.tensor(float(self.embeddings_scale), + dtype=hidden_states.dtype) + + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer(hidden_states, kv_caches[i], input_metadata) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class JAISLMHeadModel(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.transformer = JAISModel(config, linear_method) + self.lm_head_weight = self.transformer.wte.weight + if hasattr(config, "width_scale"): + self.output_logits_scale = config.width_scale + else: + self.output_logits_scale = (config.mup_output_alpha * + config.mup_width_scale) + self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, + scale=self.output_logits_scale) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if "relative_pe" in name: + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5e1f0439aec5..081e81768b23 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -10,6 +10,7 @@ "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "starcoder2": Starcoder2Config, + "jais": JAISConfig, } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 4966526f1518..150ee2ce97ad 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -5,10 +5,12 @@ # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config +from vllm.transformers_utils.configs.jais import JAISConfig __all__ = [ "ChatGLMConfig", "MPTConfig", "RWConfig", "Starcoder2Config", + "JAISConfig", ] diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py new file mode 100644 index 000000000000..94f438716f8b --- /dev/null +++ b/vllm/transformers_utils/configs/jais.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2023 Cerebras Systems. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAIS configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class JAISConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a + [`JAISModel`]. It is used to instantiate a JAIS model according to the + specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the JAIS model. Defines the number of different + tokens that can be represented by the + `inputs_ids` passed when calling [`JAISModel`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used + with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the + Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set + it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list + `["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in + the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, + defaults to `False`): + Whether to additionally scale attention weights by + `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention + (dot-product) + and upcast attention dot-product/softmax to float() when training + with mixed precision. + position_embedding_type (`str`, *optional*, defaults to `"learned"`): + Positional embedding can be either `"alibi"` or `"learned"`. + mup_width_scale (`float`, *optional*, defaults to 1.0): + muP parameter to scale learning rate and initializers. Calculated + as (`d_model,0 / d_model`), where + `d_model` is the model's width and `d_model,0` is the proxy + model's width. + mup_embeddings_scale (`float`, *optional*, defaults to 1.0): + muP parameter to scale token and position embeddings. + mup_output_alpha (`float`, *optional*, defaults to 1.0): + muP parameter to scale output logits + (`output_logits_scale = mup_output_alpha * mup_width_scale`). + mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`): + Scale attention weights by dividing by hidden_size instead of + sqrt(hidden_size). Need to set scale_attn_weights to `True` as + well. + alibi_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for ALiBi + embeddings. Currently only supports linear + scaling strategy. Can specify either the scaling `factor` (must be + a float greater than 1) for fixed scaling + or `train_seq_len` for dynamic scaling on input samples with + sequence length > `train_seq_len`. The expected + formats are `{"type": strategy name, "factor": scaling factor}` or + `{"type": strategy name, + "train_seq_len": training sequence length}`. + architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']): + architecture names for Jais. + + Example: + + ```python + >>> from transformers import JAISConfig, JAISModel + + >>> # Initializing a JAIS configuration + >>> configuration = JAISConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = JAISModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "jais" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1024, + n_embd=768, + n_layer=12, + n_head=12, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + position_embedding_type="learned", + mup_width_scale=1.0, + mup_embeddings_scale=1.0, + mup_output_alpha=1.0, + mup_scale_qk_dot_by_d=False, + alibi_scaling=None, + architectures=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.position_embedding_type = position_embedding_type + self.mup_width_scale = mup_width_scale + self.mup_embeddings_scale = mup_embeddings_scale + self.mup_output_alpha = mup_output_alpha + self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d + + self.alibi_scaling = alibi_scaling + self._alibi_scaling_validation() + if architectures is None: + architectures = ["JAISLMHeadModel"] + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + architectures=architectures, + **kwargs, + ) + + def _alibi_scaling_validation(self): + """ + Validate the `alibi_scaling` configuration. + """ + if self.alibi_scaling is None: + return + + if (not isinstance(self.alibi_scaling, dict) + or len(self.alibi_scaling) != 2): + raise ValueError( + "`alibi_scaling` must be a dictionary with two fields," + "`type` and `factor` or `type` and `train_seq_len`, " + f"got {self.alibi_scaling}") + alibi_scaling_type = self.alibi_scaling.get("type", None) + alibi_scaling_factor = self.alibi_scaling.get("factor", None) + alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) + if alibi_scaling_type is None or alibi_scaling_type != "linear": + raise ValueError(f"`alibi_scaling`'s type field must be 'linear'," + f"got {alibi_scaling_type}") + if (alibi_scaling_factor is not None + and not isinstance(alibi_scaling_factor, float) + or alibi_scaling_factor <= 1.0): + raise ValueError( + f"`alibi_scaling`'s factor field must be a float > 1.0," + f"got {alibi_scaling_factor}") + if (alibi_dynamic_scaling is not None + and not isinstance(alibi_dynamic_scaling, int) + or alibi_dynamic_scaling <= 1): + raise ValueError( + f"`alibi_scaling`'s `train_seq_len` field must be an" + f"integer > 1, got {alibi_dynamic_scaling}")