@@ -112,38 +112,58 @@ class ModelConfig:
             Defaults to 'auto' which defaults to 'hf'.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
+        pooling_type: The pooling method to use in the embedding model.
+        pooling_norm: Whether to normalize the pooled data in the
+            embedding model.
+        pooling_softmax: Whether to apply softmax to the pooled data in
+            the embedding model.
+        pooling_step_tag_id: If not -1, only the score corresponding to
+            pooling_step_tag_id in the generated sentence is returned;
+            otherwise, the scores for all tokens are returned.
+        pooling_returned_token_ids: A list of indices into the vocabulary
+            dimension to extract, such as the token IDs of good_token and
+            bad_token in the math-shepherd-mistral-7b-prm model.
     """

-    def __init__(self,
-                 model: str,
-                 task: Union[TaskOption, _Task],
-                 tokenizer: str,
-                 tokenizer_mode: str,
-                 trust_remote_code: bool,
-                 dtype: Union[str, torch.dtype],
-                 seed: int,
-                 revision: Optional[str] = None,
-                 code_revision: Optional[str] = None,
-                 rope_scaling: Optional[dict] = None,
-                 rope_theta: Optional[float] = None,
-                 tokenizer_revision: Optional[str] = None,
-                 max_model_len: Optional[int] = None,
-                 spec_target_max_model_len: Optional[int] = None,
-                 quantization: Optional[str] = None,
-                 quantization_param_path: Optional[str] = None,
-                 enforce_eager: Optional[bool] = None,
-                 max_context_len_to_capture: Optional[int] = None,
-                 max_seq_len_to_capture: Optional[int] = None,
-                 max_logprobs: int = 20,
-                 disable_sliding_window: bool = False,
-                 skip_tokenizer_init: bool = False,
-                 served_model_name: Optional[Union[str, List[str]]] = None,
-                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-                 use_async_output_proc: bool = True,
-                 override_neuron_config: Optional[Dict[str, Any]] = None,
-                 config_format: ConfigFormat = ConfigFormat.AUTO,
-                 chat_template_text_format: str = "string",
-                 mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(
+        self,
+        model: str,
+        task: Union[TaskOption, _Task],
+        tokenizer: str,
+        tokenizer_mode: str,
+        trust_remote_code: bool,
+        dtype: Union[str, torch.dtype],
+        seed: int,
+        revision: Optional[str] = None,
+        code_revision: Optional[str] = None,
+        rope_scaling: Optional[dict] = None,
+        rope_theta: Optional[float] = None,
+        tokenizer_revision: Optional[str] = None,
+        max_model_len: Optional[int] = None,
+        spec_target_max_model_len: Optional[int] = None,
+        quantization: Optional[str] = None,
+        quantization_param_path: Optional[str] = None,
+        enforce_eager: Optional[bool] = None,
+        max_context_len_to_capture: Optional[int] = None,
+        max_seq_len_to_capture: Optional[int] = None,
+        max_logprobs: int = 20,
+        disable_sliding_window: bool = False,
+        skip_tokenizer_init: bool = False,
+        served_model_name: Optional[Union[str, List[str]]] = None,
+        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+        use_async_output_proc: bool = True,
+        override_neuron_config: Optional[Dict[str, Any]] = None,
+        config_format: ConfigFormat = ConfigFormat.AUTO,
+        chat_template_text_format: str = "string",
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+        pooling_type: Optional[str] = None,
+        pooling_norm: Optional[bool] = None,
+        pooling_softmax: Optional[bool] = None,
+        pooling_step_tag_id: Optional[int] = None,
+        pooling_returned_token_ids: Optional[List[int]] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
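
For context, a hedged sketch of how the widened constructor might be exercised once this lands. The model name and pooling values are illustrative assumptions, not part of the diff, and constructing a real ModelConfig also fetches the HF config for the named model:

# Illustrative only: passing the new pooling keyword arguments through.
# "MEAN" is an assumed pooling-type name; substitute whatever the pooler
# implementation actually accepts.
config = ModelConfig(
    model="intfloat/e5-mistral-7b-instruct",
    task="embedding",
    tokenizer="intfloat/e5-mistral-7b-instruct",
    tokenizer_mode="auto",
    trust_remote_code=False,
    dtype="auto",
    seed=0,
    pooling_type="MEAN",
    pooling_norm=True,
)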
@@ -224,6 +244,13 @@ def __init__(self,
         supported_tasks, task = self._resolve_task(task, self.hf_config)
         self.supported_tasks = supported_tasks
         self.task: Final = task
+        self.pooler_config = self._init_pooler_config(
+            pooling_type,
+            pooling_norm,
+            pooling_softmax,
+            pooling_step_tag_id,
+            pooling_returned_token_ids,
+        )

         self._verify_quantization()
         self._verify_cuda_graph()
@@ -242,6 +269,23 @@ def _init_multimodal_config(

         return None

+    def _init_pooler_config(
+        self,
+        pooling_type: Optional[str] = None,
+        pooling_norm: Optional[bool] = None,
+        pooling_softmax: Optional[bool] = None,
+        pooling_step_tag_id: Optional[int] = None,
+        pooling_returned_token_ids: Optional[List[int]] = None
+    ) -> Optional["PoolerConfig"]:
+        if self.task == "embedding":
+            return PoolerConfig(
+                pooling_type=pooling_type,
+                pooling_norm=pooling_norm,
+                pooling_softmax=pooling_softmax,
+                pooling_step_tag_id=pooling_step_tag_id,
+                pooling_returned_token_ids=pooling_returned_token_ids)
+        return None
+
     def _init_attention_free(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_attention_free_model(architectures)
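
The gating above is small enough to restate standalone. A minimal, runnable sketch of the same logic, with PoolerConfig stubbed in (the real dataclass appears in the next hunk):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PoolerConfig:
    pooling_type: Optional[str] = None
    pooling_norm: Optional[bool] = None
    pooling_softmax: Optional[bool] = None
    pooling_step_tag_id: Optional[int] = None
    pooling_returned_token_ids: Optional[List[int]] = None

def init_pooler_config(task: str, **kwargs) -> Optional[PoolerConfig]:
    # Only embedding-task models get a pooler config; generation
    # models leave ModelConfig.pooler_config as None.
    if task == "embedding":
        return PoolerConfig(**kwargs)
    return None

assert init_pooler_config("generate") is None
assert init_pooler_config("embedding", pooling_norm=True).pooling_norm is True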
@@ -1660,6 +1704,17 @@ class MultiModalConfig:
     # TODO: Add configs to init vision tower or not.


+@dataclass
+class PoolerConfig:
+    """Controls the behavior of the pooler in an embedding model."""
+
+    pooling_type: Optional[str] = None
+    pooling_norm: Optional[bool] = None
+    pooling_softmax: Optional[bool] = None
+    pooling_step_tag_id: Optional[int] = None
+    pooling_returned_token_ids: Optional[List[int]] = None
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
     "float16": torch.float16,
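
To tie the new fields back to the motivating model: a hedged example of a PoolerConfig for a process-reward model along the lines of math-shepherd-mistral-7b-prm, using the dataclass added above. Every concrete value is a placeholder; real IDs would come from tokenizer lookups of the step tag and the good/bad judgment tokens.

# Placeholder values throughout; "STEP" is an assumed pooling-type name.
prm_pooler = PoolerConfig(
    pooling_type="STEP",
    pooling_softmax=True,        # turn per-token logits into probabilities
    pooling_step_tag_id=12902,   # hypothetical ID of the step-tag token
    pooling_returned_token_ids=[648, 387],  # hypothetical good/bad token IDs
)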