 from cacheflow.model_executor.weight_utils import initialize_dummy_weights


-_MODELS = {
-    'gpt2': GPT2LMHeadModel,
-    'llama': LlamaForCausalLM,
-    'opt': OPTForCausalLM,
-    'stablelm': GPTNeoXForCausalLM,
-    'pythia': GPTNeoXForCausalLM,
-    'dolly-v2': GPTNeoXForCausalLM,
+# TODO(woosuk): Lazy-load the model classes.
+_MODEL_REGISTRY = {
+    "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "OPTForCausalLM": OPTForCausalLM,
 }

 _MEMORY_ANALYZERS = {
-    'gpt2': GPT2MemoryAnalyzer,
-    'llama': LlamaMemoryAnalyzer,
-    'opt': OPTMemoryAnalyzer,
-    'stablelm': GPTNeoXMemoryAnalyzer,
-    'pythia': GPTNeoXMemoryAnalyzer,
-    'dolly-v2': GPTNeoXMemoryAnalyzer,
+    "GPT2LMHeadModel": GPT2MemoryAnalyzer,
+    "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
+    "LlamaForCausalLM": LlamaMemoryAnalyzer,
+    "OPTForCausalLM": OPTMemoryAnalyzer,
 }
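A quick illustration of why the new keying works (not part of the diff; the checkpoint names are just examples): Hugging Face configs carry an `architectures` field, so GPT-NeoX-family checkpoints that previously needed three separate name patterns (stablelm, pythia, dolly-v2) all resolve through one registry entry.

from transformers import AutoConfig

# Both checkpoints report the same architecture class name in their configs.
for name in ("EleutherAI/pythia-70m", "databricks/dolly-v2-3b"):
    config = AutoConfig.from_pretrained(name)
    print(name, config.architectures)  # e.g. ['GPTNeoXForCausalLM']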


+def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MODEL_REGISTRY:
+            return _MODEL_REGISTRY[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
+    )
+
+
+def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MEMORY_ANALYZERS:
+            return _MEMORY_ANALYZERS[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
+    )
+
+
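A hedged usage sketch of the helpers above (assumes they are called with a config loaded via `AutoConfig`; the model name is illustrative):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/opt-125m")
model_class = _get_model_architecture(config)  # OPTForCausalLM, per the registry
# A config whose architectures are absent from the registry would instead
# raise the ValueError above, listing _MODEL_REGISTRY's keys.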
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, 'torch_dtype', None)
+    config_dtype = getattr(config, "torch_dtype", None)
     if config_dtype is None:
         config_dtype = torch.float32
-    if dtype == 'default':
+    if dtype == "default":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
@@ -51,7 +70,7 @@ def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
             # TODO(woosuk): Allow using float16 for bfloat16 models and
             # vice versa. Print a warning message and continue.
             raise ValueError(
-                f'Cannot use {torch_dtype} for {config_dtype} model.')
+                f"Cannot use {torch_dtype} for {config_dtype} model.")
     return torch_dtype


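For reference, a minimal sketch of the resolution rules visible in the hunks above (the branch for an explicitly requested dtype is only partly shown, so this sticks to what the diff exposes):

import torch
from transformers import AutoConfig

# GPT-2's config typically does not set torch_dtype, so it is treated as
# float32, and "default" then maps it to float16 per the branch above.
config = AutoConfig.from_pretrained("gpt2")
assert _get_dtype(config, "default") == torch.float16
# Requesting a dtype that conflicts with the checkpoint's dtype raises the
# "Cannot use ... for ... model." ValueError shown in the second hunk.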
@@ -65,24 +84,21 @@ def get_model(
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
     torch.set_default_dtype(torch_dtype)
-    for model_class_name, model_class in _MODELS.items():
-        if model_class_name in model_name:
-            if use_dummy_weights:
-                # Create a model instance.
-                # The weights will be initialized as empty tensors.
-                model = model_class(config)
-                model = model.cuda()
-                # NOTE(woosuk): For precise performance evaluation, we assign
-                # random values to the weights.
-                initialize_dummy_weights(model)
-            else:
-                # Create a model instance.
-                model = model_class(config)
-                # Load the weights from the cached or downloaded files.
-                model.load_weights(model_name, cache_dir, use_np_cache)
-                model = model.cuda()
-            return model.eval(), torch_dtype
-    raise ValueError(f'Unsupported model name: {model_name}')
+    model_class = _get_model_architecture(config)
+
+    # Create a model instance.
+    # The weights will be initialized as empty tensors.
+    model = model_class(config)
+    if use_dummy_weights:
+        model = model.cuda()
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+    else:
+        # Load the weights from the cached or downloaded files.
+        model.load_weights(model_name, cache_dir, use_np_cache)
+        model = model.cuda()
+    return model.eval(), torch_dtype
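A hedged usage sketch of `get_model` (its full signature is elided by the hunk header; the keyword arguments below match the parameters used in the body, but their order and defaults are assumptions):

model, torch_dtype = get_model(
    model_name="facebook/opt-125m",
    dtype="default",
    cache_dir=None,           # fall back to the default HF cache (assumption)
    use_np_cache=False,
    use_dummy_weights=False,  # set True to skip weight loading when benchmarking
)
# The returned model is on the GPU and already in eval mode.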


 def get_memory_analyzer(
@@ -95,9 +111,7 @@ def get_memory_analyzer(
 ) -> CacheFlowMemoryAnalyzer:
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
-    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
-        if model_class in model_name:
-            return memory_analyzer(
-                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
-                tensor_parallel_size)
-    raise ValueError(f'Unsupported model name: {model_name}')
+    memory_analyzer = _get_memory_analyzer(config)
+    return memory_analyzer(
+        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+        tensor_parallel_size)
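Likewise for `get_memory_analyzer` (signature elided above; argument names follow the body, and the values are illustrative assumptions):

memory_analyzer = get_memory_analyzer(
    model_name="facebook/opt-125m",
    block_size=16,              # illustrative KV-cache block size
    dtype="default",
    gpu_memory=16 * 1024**3,    # assumed byte counts
    cpu_memory=32 * 1024**3,
    tensor_parallel_size=1,
)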