Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,13 +812,12 @@ def quantized_checkpoint() -> Generator:
for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith(".weight"):
continue
# TODO: only nf4 quantization is supported for now
if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
raise NotImplementedError(
"Only bitsandbytes_nf4 quantization"
f"is supported for now. {weight_name} is fp4 quantized"
)
temp_state_dict[weight_name] = weight_tensor
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
if "quant_state.bitsandbytes" in weight_name:
temp_state_dict[weight_name] = weight_tensor.cpu().data
else:
temp_state_dict[weight_name] = weight_tensor

# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str,
Expand All @@ -827,12 +826,6 @@ def _parse_quant_state(param_name: str,
for k in temp_state_dict:
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__nf4 in CPU
quant_state[param_name +
".quant_state.bitsandbytes__nf4"] = quant_state[
param_name +
".quant_state.bitsandbytes__nf4"].cpu().data
return QuantState.from_dict(quant_state, device="cuda")

# Second iterate over all prequant and normal weights
Expand All @@ -842,8 +835,12 @@ def _parse_quant_state(param_name: str,
# Filter out all weights whose suffix is not ".weight"
if not weight_name.endswith(".weight"):
continue
if weight_name + ".quant_state.bitsandbytes__nf4" \
in temp_state_dict:

if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
in temp_state_dict) or \
(f"{weight_name}.quant_state.bitsandbytes__fp4" \
in temp_state_dict):

quant_state = _parse_quant_state(weight_name,
temp_state_dict)
weight_name = weight_name.replace(".weight", ".qweight")
Expand Down