Closed
Changes from all commits
94 commits
79c94a1
fixed fp8 conflict with aqlm
Apr 23, 2024
f8b57e4
added quantization tests to buildkite
Apr 23, 2024
7175e5b
removed commented out piece
Apr 23, 2024
7a7520d
model loaded!
Apr 23, 2024
e0b4d72
renamed
Apr 23, 2024
f96428e
stash
Apr 24, 2024
88ba83b
added static fp8
Apr 24, 2024
0848b25
to try with torch.scaled_mm
Apr 24, 2024
15882ea
stash
Apr 24, 2024
7e6b675
added way to do weight quantization
Apr 24, 2024
cc959ea
working!
Apr 24, 2024
8d68dbc
fixed llama
Apr 24, 2024
881fc65
fixed llama again
Apr 24, 2024
e6dd46f
updated names
Apr 24, 2024
7e3933b
nit
Apr 24, 2024
453a236
cleanup
Apr 24, 2024
310e0a7
cleanup
Apr 24, 2024
ab4cb02
missed file :)
Apr 24, 2024
2edd93a
Update fp8.py
robertgshaw2-redhat Apr 24, 2024
ccee5d3
Implement static scaling for Mixtral
pcmoritz Apr 24, 2024
8f71c79
fix
pcmoritz Apr 24, 2024
6eb01e0
update
pcmoritz Apr 24, 2024
dc89cbc
fix
pcmoritz Apr 24, 2024
be60845
update
pcmoritz Apr 24, 2024
4613cb5
update
pcmoritz Apr 24, 2024
3d95d86
fix
pcmoritz Apr 24, 2024
642763f
move
pcmoritz Apr 24, 2024
706e931
update
pcmoritz Apr 24, 2024
9a3c78c
lol
pcmoritz Apr 24, 2024
1b6f020
fix cuda graph
pcmoritz Apr 24, 2024
b09bcec
fix
pcmoritz Apr 24, 2024
052e2b3
update
pcmoritz Apr 24, 2024
b33c6d7
update
pcmoritz Apr 25, 2024
475f58d
refactor
pcmoritz Apr 25, 2024
56b4880
update
pcmoritz Apr 25, 2024
be37154
revert
pcmoritz Apr 25, 2024
9c54d19
format
pcmoritz Apr 25, 2024
c5155ea
Update vllm/_custom_ops.py
pcmoritz Apr 25, 2024
948cca7
Update vllm/model_executor/layers/fused_moe/fused_moe.py
pcmoritz Apr 25, 2024
3feb887
Update vllm/model_executor/models/mixtral.py
pcmoritz Apr 25, 2024
df16316
format
pcmoritz Apr 25, 2024
7b6b0fa
support static scales
Apr 25, 2024
1a3b2e1
fixed example
Apr 25, 2024
63ad2ef
Delete quantize.ipynb
robertgshaw2-redhat Apr 25, 2024
794f1a1
Update vllm/_custom_ops.py
pcmoritz Apr 25, 2024
c13b6a4
update
pcmoritz Apr 25, 2024
5a230ed
update
pcmoritz Apr 25, 2024
80069c9
format
pcmoritz Apr 25, 2024
5ce17d0
activation_scale -> act_scale
pcmoritz Apr 25, 2024
5fc0335
Update scheme->activation_scheme
mgoin Apr 25, 2024
92d5162
fix dynamic scaling -- need init to zero due to atomic update
pcmoritz Apr 25, 2024
e1bfe10
Format
mgoin Apr 25, 2024
7242600
Fix tuple type
mgoin Apr 25, 2024
8512513
Merge remote-tracking branch 'pcmoritz/mixtral-fp8-static' into fp8-s…
Apr 26, 2024
21ddbb4
stash tyler's state
Apr 26, 2024
d27015c
stash
Apr 26, 2024
1111f87
cutlass working, but slow jitting on hotpath
Apr 26, 2024
f5d32ae
first end to end run with mixtral
Apr 26, 2024
924e8ce
added missed file
Apr 26, 2024
823a2e7
Update run_fp8.py
mgoin Apr 26, 2024
81f42be
Dynamic FP8 works, but static does not (#213)
robertgshaw2-redhat Apr 27, 2024
1a4fd8a
static correctness
Apr 27, 2024
e48c981
static fp8 loading
Apr 27, 2024
02f683e
working for dense models
Apr 27, 2024
81b73ef
Update weight_utils.py
robertgshaw2-redhat Apr 27, 2024
58dbe0f
moving mixtral updates to separate pr
Apr 27, 2024
6068dc5
Merge branch 'main' into fp8-static
robertgshaw2-redhat Apr 27, 2024
a8d4b33
make ./format pass
Apr 27, 2024
5be0970
better comments in linear.py
Apr 27, 2024
ef7992b
better comments in linear.py
Apr 27, 2024
0667791
fixed opt-125
Apr 27, 2024
d8adf14
removed run_fp8.py
Apr 27, 2024
9bb1a2b
format
Apr 27, 2024
169c9ed
Cleanup opt.py
mgoin Apr 27, 2024
8ef9c7d
added testing
Apr 27, 2024
c7d6dd6
./format.sh
Apr 27, 2024
50b5823
fixed typing
Apr 27, 2024
4156ca9
fixed typing
Apr 27, 2024
3148fc9
added warning format
Apr 27, 2024
7846d67
Update opt.py
robertgshaw2-redhat Apr 27, 2024
ba408c6
formatted
Apr 27, 2024
04617fd
Update vllm/model_executor/layers/quantization/fp8.py
robertgshaw2-redhat Apr 27, 2024
cc3d395
Update vllm/model_executor/layers/quantization/fp8.py
robertgshaw2-redhat Apr 27, 2024
6005ed2
baseline mixtral loading but not correct
Apr 28, 2024
5ca78f1
Merge branch 'fp8-static' into fp8-mixtral
Apr 28, 2024
51a686b
mixtral working end-to-end
Apr 28, 2024
e01833c
added test
Apr 28, 2024
03312e4
added test
Apr 28, 2024
171fcc9
format. Codespell not happy
Apr 28, 2024
b74b0a4
removed test b/c cannot get codespell to pass
Apr 28, 2024
233963b
Update format.sh
robertgshaw2-redhat Apr 28, 2024
f60aa36
formatted
Apr 28, 2024
82a8736
Merge remote-tracking branch 'upstream/main' into fp8-mixtral
mgoin May 1, 2024
c5a68fb
Fix mixtral definition
mgoin May 1, 2024
5 changes: 5 additions & 0 deletions vllm/model_executor/model_loader/weight_utils.py
@@ -370,3 +370,8 @@ def initialize_dummy_weights(
for param in model.state_dict().values():
if torch.is_floating_point(param):
param.data.uniform_(low, high)


def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
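For context, a minimal standalone sketch of how the new all_close_1d helper behaves (the scale values here are made up for illustration): it reports whether every element of a 1-D tensor matches the first one, which the Mixtral changes below use to decide whether per-expert act_scales can safely be collapsed into a single scalar.

```python
import torch

def all_close_1d(x: torch.Tensor) -> bool:
    # Same helper as above: True iff every element matches the first.
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

# Hypothetical per-expert activation scales read from a checkpoint.
uniform_scales = torch.tensor([0.25, 0.25, 0.25, 0.25])
mixed_scales = torch.tensor([0.25, 0.30, 0.25, 0.25])

print(all_close_1d(uniform_scales))  # True  -> safe to share one scalar
print(all_close_1d(mixed_scales))    # False -> scales disagree across experts
```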
140 changes: 100 additions & 40 deletions vllm/model_executor/models/mixtral.py
@@ -46,7 +46,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
all_close_1d, default_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
@@ -78,6 +79,8 @@ def __init__(
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
self.quant_config = quant_config

# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self.use_fp8 = isinstance(quant_config, Fp8Config)
@@ -86,24 +89,28 @@
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype

# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None)

if self.use_fp8:
params_dtype = torch.float8_e4m3fn

self.ws = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
device="cuda",
dtype=self.params_dtype))
dtype=params_dtype))
self.w2s = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype))
dtype=params_dtype))

set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
@@ -112,33 +119,57 @@
"weight_loader": self.weight_loader,
})

# Scaling factors for FP8 weights
self.ws_scale = nn.Parameter(
torch.ones(
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None
self.w2s_scale = nn.Parameter(
torch.ones(
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None

# Scaling factors for FP8 activations
need_act_scales = (self.use_fp8
and quant_config.activation_scheme == "static")
self.as_scale = nn.Parameter(
torch.zeros(1, device="cuda", dtype=torch.float32),
requires_grad=False) if need_act_scales else None
self.a2s_scale = nn.Parameter(
torch.zeros(1, device="cuda", dtype=torch.float32),
requires_grad=False) if need_act_scales else None

if need_act_scales:
set_weight_attrs(self.as_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2s_scale, {
"weight_loader": self.weight_loader,
})
# Used for fp8.
self.ws_scale = None
self.w2s_scale = None
self.as_scale = None
self.a2s_scale = None

if self.use_fp8:
# WEIGHT_SCALE (for fp8)
self.ws_scale = nn.Parameter(torch.ones(self.num_total_experts,
device="cuda",
dtype=torch.float32),
requires_grad=False)

Collaborator review comment (on device="cuda"): Remove all device="cuda" in this or the next PR.

self.w2s_scale = nn.Parameter(torch.ones(self.num_total_experts,
device="cuda",
dtype=torch.float32),
requires_grad=False)

# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading())
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.ws_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s_scale, {
"weight_loader": self.weight_loader,
})

# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
self.as_scale = nn.Parameter(
torch.zeros(self.num_total_experts,
device="cuda",
dtype=torch.float32),
requires_grad=False)
self.a2s_scale = nn.Parameter(
torch.zeros(self.num_total_experts,
device="cuda",
dtype=torch.float32),
requires_grad=False)

set_weight_attrs(self.as_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2s_scale, {
"weight_loader": self.weight_loader,
})

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
@@ -153,11 +184,20 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name:
param_data[:] = param_data[:].max(loaded_weight)
if "act_scale" in weight_name or "weight_scale" in weight_name:
param_data[expert_id] = loaded_weight

def _all_close_1d(self, x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

def process_weights_after_loading(self):
if self.use_fp8:
# Fp8 is the only case where we need to process after loading.
if not self.use_fp8:
return

# If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized:
ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
@@ -168,6 +208,26 @@ def process_weights_after_loading(self):
self.ws = nn.Parameter(ws, requires_grad=False)
self.w2s = nn.Parameter(w2s, requires_grad=False)

# If checkpoint is fp8 + static, cleanup act_scales.
# Since state_dict has an act_scale per expert but our kernels
# are passed one act_scale shared across all experts.
elif self.quant_config.activation_scheme == "static":
if self.as_scale is None or self.a2s_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")

if (not all_close_1d(self.as_scale)
or not all_close_1d(self.a2s_scale)):
print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ")

self.as_scale = nn.Parameter(self.as_scale.max(),
requires_grad=False)
self.a2s_scale = nn.Parameter(self.a2s_scale.max(),
requires_grad=False)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
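To make the static act_scale cleanup in process_weights_after_loading above concrete, here is a hedged sketch of the same idea in isolation (consolidate_expert_scales and the example values are illustrative, not part of the PR): the per-expert scales loaded from a static fp8 checkpoint are reduced to the single scalar the fused MoE kernel consumes, taking the maximum and warning when the experts disagree.

```python
import torch

def consolidate_expert_scales(scales: torch.Tensor) -> torch.Tensor:
    # scales: 1-D tensor holding one act_scale per expert.
    assert scales.dim() == 1
    if not all(torch.allclose(scales[0], s) for s in scales):
        # Mirrors the print_warning_once above: scales differ, so fall back
        # to the largest (most conservative) one.
        print("act_scales are not equal across experts; using the maximum.")
    return scales.max()

# Hypothetical scales serialized per expert in a static fp8 checkpoint.
per_expert = torch.tensor([0.021, 0.025, 0.025, 0.025])
shared = consolidate_expert_scales(per_expert)  # tensor(0.0250), with warning
```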
@@ -226,13 +286,6 @@ def __init__(self,
self.rope_theta = rope_theta
self.sliding_window = sliding_window

if isinstance(quant_config, Fp8Config):
print_warning_once(
"For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved."
)
quant_config = None

self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
@@ -465,6 +518,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
]

expert_params_mapping = [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
("ws_scale" if weight_name in ["w1", "w3"] else "w2s_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
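Finally, a hedged end-to-end usage sketch of what this change enables (the model name, flags, and prompt are illustrative, and the exact invocation may differ from what the PR tests use): loading Mixtral with fp8 quantization through vLLM's LLM entry point and generating a completion.

```python
from vllm import LLM, SamplingParams

# Hypothetical invocation: quantize an fp16 Mixtral checkpoint to fp8 at load
# time via the Fp8Config path touched in this PR (dynamic activation scales).
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", quantization="fp8")

params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```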