Commit 858d643

[None][fix] Fix ModelConfig.from_pretrained get quant config file (#8647)
Signed-off-by: Tailing Yuan <[email protected]>
1 parent cc5b8b6 commit 858d643
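
Editorial note: judging from the diff, the previous code resolved a local snapshot directory by locating config.json in the Hugging Face cache, then probed sibling paths such as model_dir / 'hf_quant_config.json' with Path.exists(). For a remote repo id, a quant config file is not guaranteed to be present in that snapshot until it has been explicitly fetched, so the probe could miss a file the repo actually contains. The fix queries each file directly through transformers.utils.hub.cached_file, which downloads on demand for repo ids and resolves normally for local paths. A minimal sketch (not part of the commit) contrasting the two lookups; the repo id is hypothetical and Hub access is assumed:

    from pathlib import Path

    import transformers

    checkpoint_dir = "some-org/some-quantized-model"  # hypothetical repo id

    # Old approach: resolve the snapshot directory through config.json,
    # then join sibling paths. The probe fails if hf_quant_config.json
    # was never downloaded into the snapshot, even when the repo has it.
    model_dir = Path(
        transformers.utils.hub.cached_file(checkpoint_dir,
                                           'config.json')).parent
    print((model_dir / 'hf_quant_config.json').exists())  # can be False

    # New approach: ask the hub cache for the file itself. cached_file
    # fetches it on demand for repo ids and resolves it for local paths,
    # raising OSError only when the file is genuinely absent.
    try:
        quant_config_file = transformers.utils.hub.cached_file(
            checkpoint_dir, 'hf_quant_config.json')
    except OSError:
        quant_config_file = None
    print(quant_config_file)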

File tree

1 file changed: +16 -11 lines changed


tensorrt_llm/_torch/model_config.py

Lines changed: 16 additions & 11 deletions
@@ -268,7 +268,8 @@ def is_generation_model(model_architectures: Optional[List[str]],
         # once ModelType is used in pytorch flow.
 
     @staticmethod
-    def load_modelopt_quant_config(quant_config_file, model_dir, moe_backend):
+    def load_modelopt_quant_config(quant_config_file, checkpoint_dir,
+                                   moe_backend):
         quant_config = QuantConfig()
         layer_quant_config = None
 
@@ -288,7 +289,8 @@ def load_modelopt_quant_config(quant_config_file, model_dir, moe_backend):
             'exclude_modules', None)
 
         if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
-            mixed_quant_config_file = model_dir / 'quant_cfg.json'
+            mixed_quant_config_file = transformers.utils.hub.cached_file(
+                checkpoint_dir, 'quant_cfg.json')
             with open(mixed_quant_config_file) as fm:
                 mixed_quant_configs = json.load(fm)
         # kv_cache_quant_algo is global regardless of MIXED_PRECISION
@@ -475,31 +477,34 @@ def from_pretrained(cls,
                 checkpoint_dir,
                 trust_remote_code=trust_remote_code,
             )
-
-            # Find the cache path by looking for the config.json file which should be in all
-            # huggingface models
-            model_dir = Path(
-                transformers.utils.hub.cached_file(checkpoint_dir,
-                                                   'config.json')).parent
         else:
             raise ValueError(
                 "checkpoint_dir is None. Cannot load model config without a valid checkpoint directory."
             )
 
+        # Get cached file from path or repo id, return None if not exists.
+        def cached_file(path_or_repo_id, file_name):
+            try:
+                return transformers.utils.hub.cached_file(
+                    path_or_repo_id, file_name)
+            except OSError:
+                return None
+
         quant_config = QuantConfig()
         layer_quant_config = None
         moe_backend = kwargs.get('moe_backend', 'CUTLASS')
 
         # quantized ckpt in modelopt format
-        if (quant_config_file := model_dir / 'hf_quant_config.json').exists():
+        if quant_config_file := cached_file(checkpoint_dir,
+                                            'hf_quant_config.json'):
             quant_config, layer_quant_config = cls.load_modelopt_quant_config(
-                quant_config_file, model_dir, moe_backend)
+                quant_config_file, checkpoint_dir, moe_backend)
         # quantized ckpt in other formats
         elif hasattr(pretrained_config, "quantization_config"):
             hf_quant_config = pretrained_config.quantization_config
             quant_config, layer_quant_config = cls.load_hf_quant_config(
                 hf_quant_config, moe_backend)
-        elif (quant_config_file := model_dir / 'dtypes.json').exists():
+        elif quant_config_file := cached_file(checkpoint_dir, 'dtypes.json'):
             quant_config, layer_quant_config = cls.load_quant_config_from_dtypes_json(
                 quant_config_file, moe_backend)
 
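
With the new cached_file helper, all three quant-config probes in from_pretrained share one code path that works uniformly for local checkpoint directories and Hub repo ids. A condensed, standalone sketch of the resulting lookup order (detect_quant_format and its return values are hypothetical stand-ins; only the file-resolution order mirrors the diff):

    import transformers

    def cached_file(path_or_repo_id, file_name):
        # Get cached file from path or repo id, return None if not exists.
        try:
            return transformers.utils.hub.cached_file(path_or_repo_id,
                                                      file_name)
        except OSError:
            return None

    def detect_quant_format(checkpoint_dir, pretrained_config):
        # 1. quantized ckpt in modelopt format
        if file := cached_file(checkpoint_dir, 'hf_quant_config.json'):
            return 'modelopt', file
        # 2. quantized ckpt in other formats, carried inside config.json
        if hasattr(pretrained_config, 'quantization_config'):
            return 'hf_quantization_config', None
        # 3. dtypes.json alongside the checkpoint
        if file := cached_file(checkpoint_dir, 'dtypes.json'):
            return 'dtypes_json', file
        return None, None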
