
Commit add055e

Enhance model loader (#83)
1 parent 7c041ab commit add055e

File tree: 2 files changed, +56 -42 lines


cacheflow/core/server.py

Lines changed: 1 addition & 1 deletion
@@ -12,8 +12,8 @@
 from cacheflow.frontend.simple_frontend import SimpleFrontend
 from cacheflow.logger import init_logger
 from cacheflow.model_executor import get_memory_analyzer
-from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroup
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
 from cacheflow.worker.controller import Controller, DeviceID

cacheflow/model_executor/model_loader.py

Lines changed: 55 additions & 41 deletions
@@ -14,32 +14,51 @@
 from cacheflow.model_executor.weight_utils import initialize_dummy_weights


-_MODELS = {
-    'gpt2': GPT2LMHeadModel,
-    'llama': LlamaForCausalLM,
-    'opt': OPTForCausalLM,
-    'stablelm': GPTNeoXForCausalLM,
-    'pythia': GPTNeoXForCausalLM,
-    'dolly-v2': GPTNeoXForCausalLM,
+# TODO(woosuk): Lazy-load the model classes.
+_MODEL_REGISTRY = {
+    "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "OPTForCausalLM": OPTForCausalLM,
 }

 _MEMORY_ANALYZERS = {
-    'gpt2': GPT2MemoryAnalyzer,
-    'llama': LlamaMemoryAnalyzer,
-    'opt': OPTMemoryAnalyzer,
-    'stablelm': GPTNeoXMemoryAnalyzer,
-    'pythia': GPTNeoXMemoryAnalyzer,
-    'dolly-v2': GPTNeoXMemoryAnalyzer,
+    "GPT2LMHeadModel": GPT2MemoryAnalyzer,
+    "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
+    "LlamaForCausalLM": LlamaMemoryAnalyzer,
+    "OPTForCausalLM": OPTMemoryAnalyzer,
 }


+def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MODEL_REGISTRY:
+            return _MODEL_REGISTRY[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
+    )
+
+
+def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MEMORY_ANALYZERS:
+            return _MEMORY_ANALYZERS[arch]
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
+    )
+
+
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, 'torch_dtype', None)
+    config_dtype = getattr(config, "torch_dtype", None)
     if config_dtype is None:
         config_dtype = torch.float32
-    if dtype == 'default':
+    if dtype == "default":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
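
The key change in this hunk: model resolution is now keyed on the `architectures` field of the Hugging Face config rather than on substring matches against the model name, so checkpoints such as "stablelm", "pythia", and "dolly-v2" all resolve through "GPTNeoXForCausalLM" automatically. Below is a minimal, self-contained sketch of that lookup (not part of the diff); the stub class stands in for cacheflow's real model class, and `transformers` is assumed to be installed.

# Illustrative sketch only: a stub stands in for cacheflow's GPT2LMHeadModel.
from transformers import AutoConfig, PretrainedConfig


class GPT2LMHeadModel:  # stub; the real class lives in cacheflow.model_executor
    def __init__(self, config):
        self.config = config


_MODEL_REGISTRY = {"GPT2LMHeadModel": GPT2LMHeadModel}


def _get_model_architecture(config: PretrainedConfig):
    # config.architectures comes from the checkpoint's config.json,
    # e.g. ["GPT2LMHeadModel"] for "gpt2".
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


config = AutoConfig.from_pretrained("gpt2")
print(config.architectures)                      # ['GPT2LMHeadModel']
print(_get_model_architecture(config).__name__)  # GPT2LMHeadModel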
@@ -51,7 +70,7 @@ def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
             # TODO(woosuk): Allow using float16 for bfloat16 models and
             # vice versa. Print a warning message and continue.
             raise ValueError(
-                f'Cannot use {torch_dtype} for {config_dtype} model.')
+                f"Cannot use {torch_dtype} for {config_dtype} model.")
     return torch_dtype
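
The NOTE in `_get_dtype` is worth spelling out: many Hugging Face configs set `torch_dtype` explicitly to `None`, so `getattr(config, "torch_dtype", torch.float32)` returns `None` rather than the fallback. A short sketch (assuming current `transformers` behavior, not part of the diff):

# Sketch: why the fallback argument of getattr() is not enough here.
import torch
from transformers import GPT2Config

config = GPT2Config()  # torch_dtype attribute exists but is None
print(getattr(config, "torch_dtype", torch.float32))  # None, not float32

# The pattern used in the diff: fall back to float32 only after an
# explicit None check.
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
    config_dtype = torch.float32
print(config_dtype)  # torch.float32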

@@ -65,24 +84,21 @@ def get_model(
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
     torch.set_default_dtype(torch_dtype)
-    for model_class_name, model_class in _MODELS.items():
-        if model_class_name in model_name:
-            if use_dummy_weights:
-                # Create a model instance.
-                # The weights will be initialized as empty tensors.
-                model = model_class(config)
-                model = model.cuda()
-                # NOTE(woosuk): For precise performance evaluation, we assign
-                # random values to the weights.
-                initialize_dummy_weights(model)
-            else:
-                # Create a model instance.
-                model = model_class(config)
-                # Load the weights from the cached or downloaded files.
-                model.load_weights(model_name, cache_dir, use_np_cache)
-                model = model.cuda()
-            return model.eval(), torch_dtype
-    raise ValueError(f'Unsupported model name: {model_name}')
+    model_class = _get_model_architecture(config)
+
+    # Create a model instance.
+    # The weights will be initialized as empty tensors.
+    model = model_class(config)
+    if use_dummy_weights:
+        model = model.cuda()
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+    else:
+        # Load the weights from the cached or downloaded files.
+        model.load_weights(model_name, cache_dir, use_np_cache)
+        model = model.cuda()
+    return model.eval(), torch_dtype


 def get_memory_analyzer(
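
For context, a hypothetical call to the refactored `get_model`. The full signature is not visible in this hunk, so the keyword arguments below are assumptions inferred from the variables used in the body (`model_name`, `dtype`, `cache_dir`, `use_np_cache`, `use_dummy_weights`).

# Hypothetical usage sketch; argument names are inferred, not confirmed by
# the hunk. Requires a CUDA-capable GPU since the loader calls .cuda().
from cacheflow.model_executor.model_loader import get_model

model, torch_dtype = get_model(
    model_name="facebook/opt-125m",
    dtype="default",
    cache_dir=None,
    use_np_cache=False,
    use_dummy_weights=True,  # skip weight loading; random weights for perf tests
)
print(type(model).__name__, torch_dtype)  # e.g. OPTForCausalLM torch.float16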
@@ -95,9 +111,7 @@ def get_memory_analyzer(
 ) -> CacheFlowMemoryAnalyzer:
     config = AutoConfig.from_pretrained(model_name)
     torch_dtype = _get_dtype(config, dtype)
-    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
-        if model_class in model_name:
-            return memory_analyzer(
-                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
-                tensor_parallel_size)
-    raise ValueError(f'Unsupported model name: {model_name}')
+    memory_analyzer = _get_memory_analyzer(config)
+    return memory_analyzer(
+        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+        tensor_parallel_size)
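
Similarly, a hypothetical call to `get_memory_analyzer`; only the tail of its signature (`) -> CacheFlowMemoryAnalyzer:`) and the variables used in its body are visible in the hunk, so the argument names and values below are assumptions.

# Hypothetical usage sketch; argument names are inferred from the body of
# get_memory_analyzer, and the memory figures are made-up placeholders.
from cacheflow.model_executor.model_loader import get_memory_analyzer

memory_analyzer = get_memory_analyzer(
    model_name="facebook/opt-125m",
    block_size=16,              # KV-cache block size in tokens (assumed)
    dtype="default",
    gpu_memory=16 * 1024**3,    # 16 GiB, placeholder
    cpu_memory=32 * 1024**3,    # 32 GiB, placeholder
    tensor_parallel_size=1,
)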
