This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit faed3eb

robertgshaw2-redhat, pcmoritz, WoosukKwon, mgoin, and tlrmchlsmth authored and committed
[Kernel] Support Fp8 Checkpoints (Dynamic + Static) (vllm-project#4332)
Co-authored-by: Philipp Moritz <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
1 parent 40b286f commit faed3eb
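With this change, a checkpoint whose weights are already serialized in FP8 can be loaded directly by passing quantization="fp8"; the "dynamic" and "static" variants in the title refer to whether activation scales are computed at runtime or read from the checkpoint. A minimal usage sketch in the spirit of the new test below (the model name is taken from that test; an FP8-capable GPU such as an L4 is assumed, and the prompt is only illustrative):

from vllm import LLM, SamplingParams

# Load the FP8 checkpoint exercised by tests/models/test_fp8.py.
llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-FP8",
          quantization="fp8",
          max_model_len=1024,
          enforce_eager=True)

# Greedy decoding, matching the sampling settings used in the test.
outputs = llm.generate("What is FP8 quantization?",
                       SamplingParams(max_tokens=20, temperature=0))
print(outputs[0].outputs[0].text)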

File tree: 3 files changed, +307 -40 lines

tests/models/test_fp8.py

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
# flake8: noqa
"""Tests fp8 models against ground truth generation
Note: these tests will only pass on L4 GPU.
"""
import os

import pytest
import torch
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

MODELS = [
    "nm-testing/Meta-Llama-3-8B-Instruct-FP8",
    "meta-llama/Meta-Llama-3-8B-Instruct",
]

EXPECTED_STRS_MAP = {
    "nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
        'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
        'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
        'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
        'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
        'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**'
    ],
    "meta-llama/Meta-Llama-3-8B-Instruct": [
        'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
        'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
        'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
        'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of',
        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
        'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
    ],
}

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
fp8_not_supported = (capability <
                     QUANTIZATION_METHODS["fp8"].get_min_capability())


@pytest.mark.skipif(fp8_not_supported,
                    reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(
    example_prompts,
    model_name,
) -> None:
    model = LLM(model=model_name,
                max_model_len=MAX_MODEL_LEN,
                enforce_eager=True,
                quantization="fp8")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    formatted_prompts = [
        tokenizer.apply_chat_template([{
            "role": "user",
            "content": prompt
        }],
                                      tokenize=False,
                                      add_generation_prompt=True)
        for prompt in example_prompts
    ]

    params = SamplingParams(max_tokens=20, temperature=0)
    generations = []
    # Note: these need to be run 1 at a time due to numerical precision,
    # since the expected strs were generated this way.
    for prompt in formatted_prompts:
        outputs = model.generate(prompt, params)
        generations.append(outputs[0].outputs[0].text)
    del model

    print(generations)
    expected_strs = EXPECTED_STRS_MAP[model_name]
    for i in range(len(example_prompts)):
        generated_str = generations[i]
        expected_str = expected_strs[i]
        assert expected_str == generated_str, (
            f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")

vllm/model_executor/layers/linear.py

Lines changed: 48 additions & 10 deletions

@@ -248,6 +248,10 @@ def __init__(
             self.register_parameter("bias", None)

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
@@ -256,6 +260,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
+                                                                 loaded_weight,
+                                                                 shard_id=0)
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

@@ -325,7 +335,12 @@ def weight_loader(self,

         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
         if loaded_shard_id is None:
             # Loaded weight is already packed.
             if output_dim is None:
@@ -339,14 +354,13 @@ def weight_loader(self,
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
             for shard_id, shard_offset, shard_size in shard_offsets:
+                # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
                 # for the packing.
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
-
-                    # If marlin, we need to adjust the offset and size to
-                    # account for the tiling.
+                    # Special case for Marlin.
                     shard_size, shard_offset = adjust_marlin_shard(
                         param, shard_size, shard_offset)

@@ -361,15 +375,14 @@ def weight_loader(self,
         if output_dim is not None:
             shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
             shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
             packed_dim = getattr(param, "packed_dim", None)
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
-
-                # If marlin, we need to adjust the offset and size to
-                # account for the tiling.
+                # Special case for Marlin.
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)

@@ -378,11 +391,17 @@ def weight_loader(self,
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
             shard_size = loaded_weight.shape[0]
             shard_offset = loaded_shard_id * shard_size
             param_data = param_data.narrow(0, shard_offset, shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(
+                param_data, loaded_weight, loaded_shard_id)
+
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
             if not ignore_warning:
@@ -477,7 +496,11 @@ def weight_loader(self,
                       loaded_shard_id: Optional[str] = None):
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)

         if loaded_shard_id is None:
             # Loaded weight is already packed.
@@ -495,14 +518,14 @@ def weight_loader(self,
             ]
             packed_dim = getattr(param, "packed_dim", None)
             for shard_id, shard_offset, shard_size in shard_offsets:
+                # Special case for Quantized Weights.
                 # If quantized, we need to adjust the offset and size to account
                 # for the packing.
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor

-                    # If marlin, we need to adjust the offset and size to
-                    # account for the tiling.
+                    # Special case for Marlin.
                     shard_size, shard_offset = adjust_marlin_shard(
                         param, shard_size, shard_offset)

@@ -524,15 +547,15 @@ def weight_loader(self,
                 shard_offset = (self.num_heads +
                                 self.num_kv_heads) * self.head_size
                 shard_size = self.num_kv_heads * self.head_size
+            # Special case for Quantized Weights.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
             packed_dim = getattr(param, "packed_dim", None)
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor

-                # If marlin, we need to adjust the offset and size to
-                # account for the tiling.
+                # Special case for Marlin.
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)

@@ -545,12 +568,17 @@ def weight_loader(self,
             start_idx = shard_id * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
             shard_size = loaded_weight.shape[0]
             shard_index = ["q", "k", "v"].index(loaded_shard_id)
             param_data = param_data.narrow(0, shard_index * shard_size,
                                            shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(
+                param_data, loaded_weight, loaded_shard_id)
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
             if not ignore_warning:
@@ -642,6 +670,10 @@ def __init__(
             self.register_parameter("bias", None)

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
         tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
@@ -650,6 +682,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
+                                                                 loaded_weight,
+                                                                 shard_id=0)
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)