diff --git a/README.md b/README.md
index f62357b3b3..f3012749f2 100644
--- a/README.md
+++ b/README.md
@@ -55,7 +55,8 @@ You can contact us and communicate with us by adding our group:
|
## 🎉 News
-- 🔥2024.09.19: Supports the qwen2.5, qwen2.5-math, and qwen2.5-coder series models. Supports the qwen2-vl-72b series models.
+- 2024.09.23: Support for training and deploying pixtral-12b. Experience it using `swift infer --model_type pixtral-12b --dtype fp16`.
+- 🔥2024.09.19: Supports the qwen2.5, qwen2.5-math, and qwen2.5-coder series models. Supports the qwen2-vl-72b series models. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/2064).
- 2024.09.07: Support the `Reflection-llama3-70b` model; use it via `swift sft/infer --model_type reflection-llama_3_1-70b`.
- 2024.09.06: Support fine-tuning and inference for mplug-owl3. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/1969).
- 2024.09.05: Support for the minicpm3-4b model. Experience it using `swift infer --model_type minicpm3-4b`.
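For reference, a rough Python-API counterpart of the `swift infer` command in the 2024.09.23 entry above. This is a minimal sketch that assumes the `InferArguments`/`infer_main` entry points exported by `swift.llm`, as used in other ms-swift examples:

```python
# Sketch: Python-API counterpart of `swift infer --model_type pixtral-12b --dtype fp16`.
# Assumes swift.llm exports InferArguments/infer_main as in other ms-swift examples.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from swift.llm import InferArguments, ModelType, infer_main

if __name__ == '__main__':
    # fp16 rather than bf16, matching the note in the pixtral-12b registration below.
    infer_main(InferArguments(model_type=ModelType.pixtral_12b, dtype='fp16'))
```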
diff --git a/README_CN.md b/README_CN.md
index c8d0f3989c..cbf75cbcdf 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -56,7 +56,8 @@ SWIFT具有丰富全面的文档,请查看我们的文档网站:
## 🎉 新闻
-- 🔥2024.09.19: 支持qwen2.5、qwen2.5-math、qwen2.5-coder系列模型. 支持qwen2-vl-72b系列模型.
+- 2024.09.23: 支持pixtral-12b的训练与部署. 使用`swift infer --model_type pixtral-12b --dtype fp16`进行体验.
+- 🔥2024.09.19: 支持qwen2.5、qwen2.5-math、qwen2.5-coder系列模型. 支持qwen2-vl-72b系列模型. 最佳实践可以查看[这里](https://github.com/modelscope/ms-swift/issues/2064).
- 2024.09.07: 支持`Reflection-llama3-70b`模型, 使用`swift sft/infer --model_type reflection-llama_3_1-70b`命令即可训练和推理.
- 2024.09.06: 支持mplug-owl3的微调和推理, 最佳实践可以查看[这里](https://github.com/modelscope/ms-swift/issues/1969).
- 2024.09.05: 支持minicpm3-4b模型. 使用`swift infer --model_type minicpm3-4b`进行体验.
diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
index 9d76a72329..5920cd407f 100644
--- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
+++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
@@ -114,7 +114,7 @@
- `--test_oom_error`: 用于检测训练是否会发生OOM, 默认为`False`. 如果设置为True, 则会将训练集按max_length倒序进行排列, 方便OOM的测试. 该参数一般用于测试, 请谨慎设置.
- `--disable_tqdm`: 是否不启用tqdm, 这在`nohup`启动脚本时很有用. 默认为`False`, 即为启动tqdm.
- `--🔥lazy_tokenize`: 如果设置为False, 则在`trainer.train()`之前提前对所有文本进行预处理. 如果设置为True, 则延迟对文本进行编码, 减少预处理的等待并减少内存占用, 这在处理大数据集时很有用. 默认为`None`, 即我们会根据template的类型进行智能选择, LLM的模型通常设置为False, 多模态的模型通常设置为True(避免图片和音频加载导致过多的内存占用).
-- `--🔥preprocess_num_proc`: 在对数据集预处理时(对文本进行tokenize), 使用多进程. 默认为`1`. 与`lazy_tokenize`命令行参数一样, 用于解决预处理速度慢的问题. 但该策略无法减少内存占用, 所以如果当数据集巨大时, 建议使用`lazy_tokenize`. 推荐设置的值: 4, 8. 请注意: 当使用qwen-audio时, 该参数会强制设置为1, 因为qwen-audio的预处理函数中使用了torch的多进程, 会造成不兼容问题.
+- `--🔥preprocess_num_proc`: 在对数据集预处理时(对文本进行tokenize), 使用多进程. 默认为`1`. 与`lazy_tokenize`命令行参数一样, 用于解决预处理速度慢的问题. 但该策略无法减少内存占用, 所以如果当数据集巨大时, 建议使用`lazy_tokenize`. 推荐设置的值: 4, 8.
- `--🔥use_flash_attn`: 是否使用flash attn, 默认为`None`. 安装flash_attn的步骤可以查看[https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). 支持flash_attn的模型可以查看[LLM支持的模型](支持的模型和数据集.md#模型).
- `--ignore_args_error`: 是否忽略命令行传参错误抛出的Error, 默认为`False`. 如果需要拷贝代码到notebook中运行, 需要设置成True.
- `--🔥check_model_is_latest`: 检查模型是否是最新, 默认为`True`. 如果你需要断网进行训练, 请将该参数设置为`False`.
diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
index 8d3f0daede..6ef558f4b7 100644
--- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
+++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
@@ -492,6 +492,7 @@
|minicpm-v-v2-chat|[OpenBMB/MiniCPM-V-2](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v|✔|✘|✘|✘|timm, transformers<4.42|vision|[openbmb/MiniCPM-V-2](https://huggingface.co/openbmb/MiniCPM-V-2)|
|minicpm-v-v2_5-chat|[OpenBMB/MiniCPM-Llama3-V-2_5](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_5|✔|✔|✘|✘|timm, transformers>=4.36|vision|[openbmb/MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5)|
|minicpm-v-v2_6-chat|[OpenBMB/MiniCPM-V-2_6](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_6|✔|✔|✘|✘|timm, transformers>=4.36|vision, video|[openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)|
+|pixtral-12b|[AI-ModelScope/pixtral-12b](https://modelscope.cn/models/AI-ModelScope/pixtral-12b/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|pixtral|✘|✘|✘|✘|transformers>=4.45.0.dev0|vision|[mistral-community/pixtral-12b](https://huggingface.co/mistral-community/pixtral-12b)|
|mplug-owl2-chat|[iic/mPLUG-Owl2](https://modelscope.cn/models/iic/mPLUG-Owl2/summary)|q_proj, k_proj.multiway.0, k_proj.multiway.1, v_proj.multiway.0, v_proj.multiway.1|mplug-owl2|✔|✘|✘|✘|transformers<4.35, icecream|vision|[MAGAer13/mplug-owl2-llama2-7b](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b)|
|mplug-owl2_1-chat|[iic/mPLUG-Owl2.1](https://modelscope.cn/models/iic/mPLUG-Owl2.1/summary)|c_attn.multiway.0, c_attn.multiway.1|mplug-owl2|✔|✘|✘|✘|transformers<4.35, icecream|vision|[Mizukiluke/mplug_owl_2_1](https://huggingface.co/Mizukiluke/mplug_owl_2_1)|
|mplug-owl3-7b-chat|[iic/mPLUG-Owl3-7B-240728](https://modelscope.cn/models/iic/mPLUG-Owl3-7B-240728/summary)|^(language_model\|vision2text_model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|mplug_owl3|✔|✘|✘|✘|transformers>=4.36, icecream|vision|[mPLUG/mPLUG-Owl3-7B-240728](https://huggingface.co/mPLUG/mPLUG-Owl3-7B-240728)|
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index 57f91621f2..5003b19946 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -114,7 +114,7 @@
- `--test_oom_error`: Used to detect whether training will cause OOM, default is `False`. If set to True, will sort the training set in descending order by max_length, easy for OOM testing. This parameter is generally used for testing, use carefully.
- `--disable_tqdm`: Whether to disable tqdm, useful when launching script with `nohup`. Default is `False`, i.e. enable tqdm.
- `--🔥lazy_tokenize`: If set to False, preprocess all text before `trainer.train()`. If set to True, delay encoding text, reducing preprocessing wait and memory usage, useful when processing large datasets. Default is `None`, i.e. we intelligently choose based on template type, usually set to False for LLM models, set to True for multimodal models (to avoid excessive memory usage from loading images and audio).
-- `--🔥preprocess_num_proc`: Use multiprocessing when preprocessing dataset (tokenizing text). Default is `1`. Same as `lazy_tokenize` command line argument, used to solve slow preprocessing issue. But this strategy cannot reduce memory usage, so if dataset is huge, `lazy_tokenize` is recommended. Recommended values: 4, 8. Note: When using qwen-audio, this parameter will be forced to 1, because qwen-audio's preprocessing function uses torch's multiprocessing, which will cause compatibility issues.
+- `--🔥preprocess_num_proc`: Number of processes to use when preprocessing the dataset (tokenizing text). Default is `1`. Like the `lazy_tokenize` argument, it addresses slow preprocessing, but it does not reduce memory usage, so for very large datasets `lazy_tokenize` is recommended instead. Recommended values: 4, 8.
- `--🔥use_flash_attn`: Whether to use flash attn, default is `None`. Installation steps for flash_attn can be found at [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Models supporting flash_attn can be found in [LLM Supported Models](Supported-models-datasets.md).
- `--ignore_args_error`: Whether to ignore Error thrown by command line parameter errors, default is `False`. Set to True if need to copy code to notebook to run.
- `--🔥check_model_is_latest`: Check if model is latest, default is `True`. Set this to `False` if you need to train offline.
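The preprocessing-related flags above can also be set programmatically. A minimal sketch, assuming `swift.llm`'s `SftArguments`/`sft_main` entry points; the model and dataset names are illustrative, not a recommendation:

```python
# Sketch: setting the preprocessing-related flags through the Python API.
# SftArguments/sft_main are assumed to be exported by swift.llm; names below are illustrative.
from swift.llm import SftArguments, sft_main

if __name__ == '__main__':
    sft_args = SftArguments(
        model_type='qwen2-7b-instruct',  # any supported model_type
        dataset=['alpaca-en'],           # illustrative dataset name
        preprocess_num_proc=4,           # speeds up tokenization, but does not lower memory usage
        lazy_tokenize=True,              # prefer this for huge or multimodal datasets
        use_flash_attn=True)             # requires flash-attn to be installed
    sft_main(sft_args)
```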
diff --git a/docs/source_en/Instruction/Supported-models-datasets.md b/docs/source_en/Instruction/Supported-models-datasets.md
index 397b982746..51bd2fccb6 100644
--- a/docs/source_en/Instruction/Supported-models-datasets.md
+++ b/docs/source_en/Instruction/Supported-models-datasets.md
@@ -492,6 +492,7 @@ The table below introduces all models supported by SWIFT:
|minicpm-v-v2-chat|[OpenBMB/MiniCPM-V-2](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v|✔|✘|✘|✘|timm, transformers<4.42|vision|[openbmb/MiniCPM-V-2](https://huggingface.co/openbmb/MiniCPM-V-2)|
|minicpm-v-v2_5-chat|[OpenBMB/MiniCPM-Llama3-V-2_5](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_5|✔|✔|✘|✘|timm, transformers>=4.36|vision|[openbmb/MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5)|
|minicpm-v-v2_6-chat|[OpenBMB/MiniCPM-V-2_6](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_6|✔|✔|✘|✘|timm, transformers>=4.36|vision, video|[openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)|
+|pixtral-12b|[AI-ModelScope/pixtral-12b](https://modelscope.cn/models/AI-ModelScope/pixtral-12b/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|pixtral|✘|✘|✘|✘|transformers>=4.45.0.dev0|vision|[mistral-community/pixtral-12b](https://huggingface.co/mistral-community/pixtral-12b)|
|mplug-owl2-chat|[iic/mPLUG-Owl2](https://modelscope.cn/models/iic/mPLUG-Owl2/summary)|q_proj, k_proj.multiway.0, k_proj.multiway.1, v_proj.multiway.0, v_proj.multiway.1|mplug-owl2|✔|✘|✘|✘|transformers<4.35, icecream|vision|[MAGAer13/mplug-owl2-llama2-7b](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b)|
|mplug-owl2_1-chat|[iic/mPLUG-Owl2.1](https://modelscope.cn/models/iic/mPLUG-Owl2.1/summary)|c_attn.multiway.0, c_attn.multiway.1|mplug-owl2|✔|✘|✘|✘|transformers<4.35, icecream|vision|[Mizukiluke/mplug_owl_2_1](https://huggingface.co/Mizukiluke/mplug_owl_2_1)|
|mplug-owl3-7b-chat|[iic/mPLUG-Owl3-7B-240728](https://modelscope.cn/models/iic/mPLUG-Owl3-7B-240728/summary)|^(language_model\|vision2text_model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|mplug_owl3|✔|✘|✘|✘|transformers>=4.36, icecream|vision|[mPLUG/mPLUG-Owl3-7B-240728](https://huggingface.co/mPLUG/mPLUG-Owl3-7B-240728)|
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index 90feb3b1cd..1f5855f6b9 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -920,7 +920,7 @@ def _prepare_target_modules(self, target_modules) -> Union[List[str], str]:
target_modules.append('DEFAULT')
if 'DEFAULT' in target_modules:
target_modules.remove('DEFAULT')
- default_lora_tm = get_default_lora_target_modules(self.model_type)
+ default_lora_tm = get_default_lora_target_modules(self.model_type) or []
if isinstance(default_lora_tm, str):
return default_lora_tm
target_modules += default_lora_tm
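The `or []` guard matters because `get_default_lora_target_modules` can return `None` for a model type with no registered default, and `list += None` raises `TypeError`. A self-contained sketch of the pattern, using a hypothetical lookup table rather than swift's real registry:

```python
# Sketch of the None-guard used above; _DEFAULTS is a hypothetical stand-in table.
from typing import List, Optional, Union

_DEFAULTS = {'llama': ['q_proj', 'k_proj', 'v_proj']}

def get_default_tm(model_type: str) -> Optional[Union[List[str], str]]:
    return _DEFAULTS.get(model_type)  # None when nothing is registered

target_modules: List[str] = []
default_tm = get_default_tm('pixtral-12b') or []  # without `or []` this would stay None
if isinstance(default_tm, str):
    target_modules = [default_tm]
else:
    target_modules += default_tm  # safe: at worst adds an empty list
print(target_modules)  # []
```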
diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py
index 28c9c90ae3..77fa5ad412 100644
--- a/swift/llm/utils/model.py
+++ b/swift/llm/utils/model.py
@@ -489,6 +489,8 @@ class ModelType:
mixtral_moe_7b_instruct = 'mixtral-moe-7b-instruct'
mixtral_moe_7b_aqlm_2bit_1x16 = 'mixtral-moe-7b-aqlm-2bit-1x16' # aqlm
mixtral_moe_8x22b_v1 = 'mixtral-moe-8x22b-v1'
+
+ pixtral_12b = 'pixtral-12b'
# wizardlm
wizardlm2_7b_awq = 'wizardlm2-7b-awq'
wizardlm2_8x22b = 'wizardlm2-8x22b'
@@ -1013,6 +1015,26 @@ def _output_device_map_hook(module, input, output):
return output.to(input[0].device)
+@register_model(
+ ModelType.pixtral_12b,
+ 'AI-ModelScope/pixtral-12b',
+ LoRATM.llava,
+ TemplateType.pixtral,
+ # torch_dtype=torch.float16, # Please do not use bf16.
+ requires=['transformers>=4.45.0.dev0'],
+ placeholder_tokens=['[IMG]'],
+ tags=['multi-modal', 'vision'],
+ hf_model_id='mistral-community/pixtral-12b')
+def get_model_tokenizer_pixtral(model_dir: str, *args, **kwargs):
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
+ processor = AutoProcessor.from_pretrained(model_dir)
+ kwargs['automodel_class'] = LlavaForConditionalGeneration
+ kwargs['tokenizer'] = processor.tokenizer
+ model, tokenizer = get_model_tokenizer_from_repo(model_dir, *args, **kwargs)
+ tokenizer.processor = processor
+ return model, tokenizer
+
+
@register_model(
ModelType.cogvlm2_video_13b_chat,
'ZhipuAI/cogvlm2-video-llama3-chat',
@@ -4452,7 +4474,16 @@ def get_model_tokenizer_internvl(model_dir: str,
model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
use_flash_attn = kwargs.pop('use_flash_attn', False)
- model_config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
+ if hasattr(model_config.llm_config, 'attn_implementation'):
+ attr = 'attn_implementation'
+ else:
+ attr = '_attn_implementation'
+ if use_flash_attn:
+ setattr(model_config.llm_config, attr, 'flash_attention_2')
+ else:
+ setattr(model_config.llm_config, attr, 'eager')
+ setattr(model_config.llm_config, f'{attr}_internal', None)
+
model_quant_config = getattr(model_config, 'quantization_config', None)
use_bnb = False
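For context on the new registration, here is a rough loading sketch built from swift's usual helpers. It is a sketch, not documented pixtral usage: it assumes `get_model_tokenizer`, `get_default_template_type`, `get_template` and `inference` behave for pixtral as they do for the other multimodal models, and the image URL is purely illustrative:

```python
# Sketch: loading the newly registered pixtral-12b and running one query.
# Assumes the usual swift.llm helpers; the image URL is illustrative.
import torch
from swift.llm import (ModelType, get_default_template_type, get_model_tokenizer, get_template, inference)

model_type = ModelType.pixtral_12b
template_type = get_default_template_type(model_type)  # expected to resolve to 'pixtral'

# fp16, per the registration note that bf16 should not be used.
model, tokenizer = get_model_tokenizer(model_type, torch.float16, model_kwargs={'device_map': 'auto'})
template = get_template(template_type, tokenizer)  # PixtralTemplate; tokenizer.processor is attached by the loader

images = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png']  # illustrative image
response, _ = inference(model, template, 'Describe the image.', images=images)
print(response)
```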
diff --git a/swift/llm/utils/preprocess.py b/swift/llm/utils/preprocess.py
index bbff06d623..22ec11f9cd 100644
--- a/swift/llm/utils/preprocess.py
+++ b/swift/llm/utils/preprocess.py
@@ -41,6 +41,7 @@ def new_call_func(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
self.shared_shm_name = shm.name
buffer = shm.buf
self.column_state = np.ndarray((len(self.key_mapping), ), dtype=np.bool_, buffer=buffer)
+ self.column_state[:] = 0
dataset = call_func(self, dataset)
if isinstance(dataset, HfIterableDataset) and dataset.features is None:
features = next(iter(dataset)).keys()
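The added `self.column_state[:] = 0` makes the initial state of the shared-memory flags explicit instead of relying on whatever bytes the segment already holds. A standalone sketch of the pattern (not swift's actual preprocessing code):

```python
# Sketch: a bool flag array backed by shared memory, zeroed explicitly before use.
import numpy as np
from multiprocessing import shared_memory

n_columns = 4
shm = shared_memory.SharedMemory(create=True, size=n_columns)  # 1 byte per np.bool_ flag
column_state = np.ndarray((n_columns,), dtype=np.bool_, buffer=shm.buf)
column_state[:] = 0  # start from a known state rather than trusting the buffer contents

column_state[2] = True  # e.g. a worker marks column 2 as present
print(column_state.tolist())  # [False, False, True, False]

shm.close()
shm.unlink()
```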
diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index 2d573648e7..2e0252659f 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -82,6 +82,7 @@ class TemplateType:
idefics3 = 'idefics3'
mistral_nemo = 'mistral-nemo'
+ pixtral = 'pixtral'
openbuddy = 'openbuddy'
openbuddy2 = 'openbuddy2'
internlm = 'internlm'
@@ -1530,6 +1531,69 @@ class Qwen2VLGenerationTemplate(_Qwen2VLTemplateMixin, DefaultGenerationTemplate
register_template(TemplateType.qwen2_vl_generation, Qwen2VLGenerationTemplate(), lazy_tokenize=True, is_generation=True)
+def _gather_list(batch: List[Dict[str, Any]], attr_name: str) -> Optional[List[Any]]:
+ # Gather `attr_name` from every sample in the batch and return the values flattened into one list.
+ res = []
+ for b in batch:
+ if b.get(attr_name) is not None:
+ res += b.pop(attr_name)
+ return res
+
+
+class PixtralTemplate(Template):
+
+ def __init__(self):
+ super().__init__(['{{SYSTEM}}'], ['[INST]{{QUERY}}[/INST]'], [''], [''], None)
+
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+ example: Dict[str, Any]) -> List[Context]:
+ return ['[IMG]']
+
+ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ inputs, _ = super()._encode(example)
+ if len(inputs) == 0:
+ return inputs, {}
+ processor = self.tokenizer.processor
+ images = example['images']
+ input_ids = inputs['input_ids']
+ labels = inputs['labels']
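+ # 10 is the token id that the '[IMG]' placeholder inserted by replace_tag encodes to.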
+ idx_list = _findall(input_ids, 10)
+ if idx_list:
+ image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
+ inputs['pixel_values'] = image_inputs['pixel_values'][0]
+ image_sizes = image_inputs['image_sizes'][0]
+ added_tokens_len = 0
+ for idx, image_size in zip(idx_list, image_sizes):
+ height, width = image_size
+ num_height_tokens = height // processor.patch_size
+ num_width_tokens = width // processor.patch_size
+ replace_tokens = [processor.image_token * num_width_tokens + processor.image_break_token] * (
+ num_height_tokens - 1)
+ replace_tokens += [processor.image_token * num_width_tokens + processor.image_end_token]
+ # Join the per-row token strings into a single replacement string
+ replace_str = ''.join(replace_tokens)
+ img_tokens: List[int] = self.tokenizer.encode(replace_str, add_special_tokens=False)
+ input_ids = input_ids[:idx + added_tokens_len] + img_tokens + input_ids[idx + added_tokens_len + 1:]
+ if labels is not None:
+ labels = labels[:idx + added_tokens_len] + [-100] * len(img_tokens) + labels[idx + added_tokens_len
+ + 1:]
+ added_tokens_len += len(img_tokens) - 1
+ inputs['input_ids'] = input_ids
+ inputs['labels'] = labels
+
+ return inputs, {}
+
+ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
+ pixel_values = _gather_list(batch, 'pixel_values')
+ res = super().data_collator(batch, padding_to)
+ if pixel_values:
+ res['pixel_values'] = pixel_values
+ return res
+
+
+register_template(TemplateType.pixtral, PixtralTemplate(), lazy_tokenize=True)
+
+
class YiCoderTemplate(ChatmlTemplate):
system = 'You are a helpful assistant.'
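To make the `[IMG]` expansion in `PixtralTemplate._encode` concrete: every row of image patches becomes `num_width_tokens` image tokens plus one separator, `[IMG_BREAK]` for all but the last row and `[IMG_END]` for the last. A quick back-of-the-envelope check; the patch size and resized image size below are assumed values, not numbers taken from the processor:

```python
# Sketch: how many tokens replace a single '[IMG]' placeholder in PixtralTemplate._encode.
patch_size = 16                    # illustrative patch size
height, width = 512, 768           # illustrative post-resize image size from the image processor

num_height_tokens = height // patch_size  # 32 rows of patches
num_width_tokens = width // patch_size    # 48 patches per row

# Each row contributes num_width_tokens image tokens plus 1 separator token.
total_tokens = num_height_tokens * (num_width_tokens + 1)
print(total_tokens)  # 1568 tokens replace one '[IMG]'; labels are masked to -100 over all of them
```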
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index b14ba5282f..548c15c534 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -473,6 +473,7 @@ def dynamic_vit_gradient_checkpointing(model, model_type: str) -> None:
if module_list is None:
continue
_add_gradient_checkpointing(module_list)
+ logger.info(f'Automatically added gradient_checkpointing to {vision_tower.__class__}.')
def find_embedding(model: Module) -> List[str]: