@@ -268,7 +268,8 @@ def is_generation_model(model_architectures: Optional[List[str]],
268 268	        # once ModelType is used in pytorch flow.
269 269	
270 270	    @staticmethod
271	    -    def load_modelopt_quant_config(quant_config_file, model_dir, moe_backend):
	271	+    def load_modelopt_quant_config(quant_config_file, checkpoint_dir,
	272	+                                   moe_backend):
272 273	        quant_config = QuantConfig()
273 274	        layer_quant_config = None
274 275	
@@ -288,7 +289,8 @@ def load_modelopt_quant_config(quant_config_file, model_dir, moe_backend):
288 289	                'exclude_modules', None)
289 290	
290 291	        if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
291	    -            mixed_quant_config_file = model_dir / 'quant_cfg.json'
	292	+            mixed_quant_config_file = transformers.utils.hub.cached_file(
	293	+                checkpoint_dir, 'quant_cfg.json')
292 294	            with open(mixed_quant_config_file) as fm:
293 295	                mixed_quant_configs = json.load(fm)
294 296	            # kv_cache_quant_algo is global regardless of MIXED_PRECISION
@@ -475,31 +477,34 @@ def from_pretrained(cls,
475 477	                checkpoint_dir,
476 478	                trust_remote_code=trust_remote_code,
477 479	            )
478	    -
479	    -            # Find the cache path by looking for the config.json file which should be in all
480	    -            # huggingface models
481	    -            model_dir = Path(
482	    -                transformers.utils.hub.cached_file(checkpoint_dir,
483	    -                                                   'config.json')).parent
484 480	        else:
485 481	            raise ValueError(
486 482	                "checkpoint_dir is None. Cannot load model config without a valid checkpoint directory."
487 483	            )
488 484	
	485	+        # Get cached file from path or repo id, return None if not exists.
	486	+        def cached_file(path_or_repo_id, file_name):
	487	+            try:
	488	+                return transformers.utils.hub.cached_file(
	489	+                    path_or_repo_id, file_name)
	490	+            except OSError:
	491	+                return None
	492	+
489 493	        quant_config = QuantConfig()
490 494	        layer_quant_config = None
491 495	        moe_backend = kwargs.get('moe_backend', 'CUTLASS')
492 496	
493 497	        # quantized ckpt in modelopt format
494	    -        if (quant_config_file := model_dir / 'hf_quant_config.json').exists():
	498	+        if quant_config_file := cached_file(checkpoint_dir,
	499	+                                            'hf_quant_config.json'):
495 500	            quant_config, layer_quant_config = cls.load_modelopt_quant_config(
496	    -                quant_config_file, model_dir, moe_backend)
	501	+                quant_config_file, checkpoint_dir, moe_backend)
497 502	        # quantized ckpt in other formats
498 503	        elif hasattr(pretrained_config, "quantization_config"):
499 504	            hf_quant_config = pretrained_config.quantization_config
500 505	            quant_config, layer_quant_config = cls.load_hf_quant_config(
501 506	                hf_quant_config, moe_backend)
502	    -        elif (quant_config_file := model_dir / 'dtypes.json').exists():
	507	+        elif quant_config_file := cached_file(checkpoint_dir, 'dtypes.json'):
503 508	            quant_config, layer_quant_config = cls.load_quant_config_from_dtypes_json(
504 509	                quant_config_file, moe_backend)
505 510	
0 commit comments