 device_woqlinear_mapping = {"cpu": INCWeightOnlyLinear, "hpu": HPUWeightOnlyLinear}


-def save(model, output_dir="./saved_results"):
+def save(model, output_dir="./saved_results", format=LoadFormat.DEFAULT, **kwargs):
     """Save the quantized model and config to the output path.

     Args:
         model (torch.nn.module): raw fp32 model or prepared model.
         output_dir (str, optional): output path to save.
+        format (str, optional): The format in which to save the model. Options include "default" and "huggingface". Defaults to "default".
+        kwargs: Additional arguments for specific formats. For example:
+            - safe_serialization (bool): Whether to use safe serialization when saving (only applicable for 'huggingface' format). Defaults to True.
+            - tokenizer (Tokenizer, optional): The tokenizer to be saved along with the model (only applicable for 'huggingface' format).
+            - max_shard_size (str, optional): The maximum size for each shard (only applicable for 'huggingface' format). Defaults to "5GB".
     """
     os.makedirs(output_dir, exist_ok=True)
+    if format == LoadFormat.HUGGINGFACE:  # pragma: no cover
+        config = model.config
+        quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
+        if quantization_config and "backend" in quantization_config and "auto_round" in quantization_config["backend"]:
+            safe_serialization = kwargs.get("safe_serialization", True)
+            tokenizer = kwargs.get("tokenizer", None)
+            max_shard_size = kwargs.get("max_shard_size", "5GB")
+            if tokenizer is not None:
+                tokenizer.save_pretrained(output_dir)
+            del model.save  # drop the attached save() method before delegating to save_pretrained
+            model.save_pretrained(output_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
+            return
+
     qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
     # saving process
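
Usage sketch for the extended save() (illustrative only, not part of the patch: q_model stands for a model already quantized with the auto_round backend, and the tokenizer id is a placeholder):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder checkpoint id
    save(
        q_model,                       # assumed: an AutoRound-quantized WOQ model
        output_dir="./saved_results",
        format="huggingface",          # string form documented above; LoadFormat.HUGGINGFACE is the enum/constant form
        tokenizer=tokenizer,           # saved alongside the model
        safe_serialization=True,
        max_shard_size="5GB",
    )
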
@@ -203,8 +221,15 @@ def load_hf_format_woq_model(self):

         # get model class and config
         model_class, config = self._get_model_class_and_config()
-        self.quantization_config = config.quantization_config
-
+        self.quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
+        if (
+            self.quantization_config and "auto_round" in self.quantization_config.get("backend", "")
+        ):  # pragma: no cover
+            # load an auto_round-format quantized model via transformers
+            from auto_round import AutoRoundConfig  # noqa: F401
+
+            model = model_class.from_pretrained(self.model_name_or_path)
+            return model
         # get loaded state_dict
         self.loaded_state_dict = self._get_loaded_state_dict(config)
         self.loaded_state_dict_keys = list(set(self.loaded_state_dict.keys()))
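
Load-side sketch of the new auto_round branch (hedged: assumes this loader is reached through the package-level load() helper with format="huggingface"; the checkpoint path is a placeholder):

    from neural_compressor.torch.quantization import load

    # an auto_round-format WOQ checkpoint on disk or on the Hub (placeholder id)
    model = load("path/or/repo-id-of-autoround-checkpoint", format="huggingface")
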
@@ -400,7 +425,7 @@ def _get_model_class_and_config(self):
         trust_remote_code = self.kwargs.pop("trust_remote_code", None)
         kwarg_attn_imp = self.kwargs.pop("attn_implementation", None)

-        config = AutoConfig.from_pretrained(self.model_name_or_path)
+        config = AutoConfig.from_pretrained(self.model_name_or_path, trust_remote_code=trust_remote_code)
         # quantization_config = config.quantization_config

         if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:  # pragma: no cover
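
The AutoConfig call now forwards the caller's trust_remote_code, so configs shipped as custom modeling code resolve the same way the model itself will. Minimal illustration (placeholder repo id, not from this patch):

    from transformers import AutoConfig

    # repos that ship their own configuration code require trust_remote_code=True
    config = AutoConfig.from_pretrained("some-org/model-with-custom-code", trust_remote_code=True)
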