4 changes: 2 additions & 2 deletions python/mlc_llm/interface/calibrate.py
@@ -30,7 +30,7 @@ def get():

     @tvm.register_func("mlc_llm.calibration_observer")
     @staticmethod
-    def callback(name, mode, value, out_value):
+    def callback(name: str, mode: str, value: "tvm.nd.NDArray", out_value: "tvm.nd.NDArray"):
         """The callback function to update the saved calibration parameters."""
         instance = CalibrationObserver.get()
         if mode == "max":
@@ -48,7 +48,7 @@ def save_params(self, output: str):
         tvmjs.dump_ndarray_cache(
             self.params,
             output,
-            encode_format="raw",
+            encode_format="f32-to-bf16",
             meta_data=None,
             show_progress=False,
             update_if_exists=True,
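
Note on the calibrate.py change (illustrative, not the repo's code): in "max" mode the observer keeps a running element-wise maximum of each reported scale across calibration batches, and save_params then dumps the accumulated dictionary via tvmjs.dump_ndarray_cache. A minimal standalone sketch of that bookkeeping, with hypothetical names:

import numpy as np

class MaxCalibrationObserver:
    """Toy stand-in for the calibration observer; not the MLC implementation."""

    def __init__(self):
        self.params = {}  # parameter name -> running-max scale array

    def update(self, name: str, mode: str, value: np.ndarray) -> np.ndarray:
        if mode != "max":
            raise ValueError(f"unsupported calibration mode: {mode}")
        prev = self.params.get(name)
        # element-wise running maximum over calibration batches
        self.params[name] = value if prev is None else np.maximum(prev, value)
        return self.params[name]

obs = MaxCalibrationObserver()
obs.update("layer0.q_calibration_scale", "max", np.array([0.02], dtype=np.float32))
obs.update("layer0.q_calibration_scale", "max", np.array([0.05], dtype=np.float32))
print(obs.params["layer0.q_calibration_scale"])  # [0.05]
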
2 changes: 1 addition & 1 deletion python/mlc_llm/op/moe_matmul.py
@@ -238,7 +238,7 @@ def access_x(x, e, j):
     def _func_with_scale(
         x: T.Buffer((x_leading_dim, in_features), model_dtype),
         w: T.Buffer((local_experts, out_features, num_storage), storage_dtype),
-        scale: T.Buffer((1,), model_dtype),
+        scale: T.Buffer((1,), "float32"),
         indptr: T.Buffer((1, experts_per_tok), "int32"),
         o: T.Buffer((experts_per_tok, out_features), model_dtype),
     ):
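
The grouped-GEMV kernel above now takes its per-tensor scale as a float32 buffer rather than the model dtype. A rough NumPy illustration (values and names are made up) of keeping the scale in float32 and casting it to the model dtype only where the weight is dequantized:

import numpy as np

model_dtype = np.float16
w_int = np.array([[12, -7], [3, 25]], dtype=np.int8)  # stand-in for the packed quantized weight
scale = np.array([0.0125], dtype=np.float32)          # per-tensor scale kept in float32

# Cast the scale to the model dtype only at the point of use,
# mirroring how the kernel consumes the float32 scale buffer.
w_dequant = w_int.astype(model_dtype) * scale.astype(model_dtype)
print(w_dequant.dtype, w_dequant)
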
8 changes: 4 additions & 4 deletions python/mlc_llm/quantization/fp8_quantization.py
@@ -85,8 +85,8 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor:  # pylint: disa
                 [f"{self.name}.q_calibration_scale", "max", x_scale],
                 out=nn.Tensor.placeholder(x_scale.shape, x_scale.dtype),
             )
-            x_q = (x / x_scale).astype(self.config.activation_dtype)
-            x = x_q.astype(self.config.model_dtype) * x_scale
+            x_q = (x / x_scale.astype(x.dtype)).astype(self.config.activation_dtype)
+            x = x_q.astype(self.config.model_dtype) * x_scale.astype(self.config.model_dtype)

         if indptr.ndim == 2:
             assert indptr.shape[0] == 1
@@ -97,12 +97,12 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor:  # pylint: disa
         if extern.get_store().cutlass_group_gemm:
             if self.config.calibration_mode == "inference":
                 if self.q_calibration_scale is not None:
-                    x /= self.q_calibration_scale
+                    x /= self.q_calibration_scale.astype(x.dtype)
                 x_q = nn.op.astype(x, dtype=self.config.activation_dtype)
                 x_scale = self.q_calibration_scale

             scale = (
-                (x_scale * self.q_scale).astype("float32")
+                x_scale * self.q_scale
                 if self.q_scale is not None
                 else nn.wrap_nested(
                     relax.Constant(nd.array(np.array([1.0]).astype("float32"))), "scale"
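
The added .astype calls keep every binary op dtype-consistent now that the calibration scale is stored in float32 while activations stay in the model dtype. A hedged sketch of the quantize/dequantize round trip, with float16 standing in for the FP8 activation dtype so it runs on plain NumPy, and 448 assumed as the e4m3 representable maximum:

import numpy as np

model_dtype = np.float16
activation_dtype = np.float16  # stand-in for an FP8 (e4m3/e5m2) storage dtype
max_int_value = 448.0          # assumed e4m3 finite maximum

x = np.random.randn(4, 8).astype(model_dtype)
x_scale = np.array([float(np.abs(x).max()) / max_int_value], dtype=np.float32)

# Cast the float32 scale to x's dtype before dividing, as in the diff.
x_q = (x / x_scale.astype(x.dtype)).astype(activation_dtype)

# Dequantize back to the model dtype for the non-fused path.
x_dq = x_q.astype(model_dtype) * x_scale.astype(model_dtype)
print("max round-trip error:", float(np.max(np.abs(x - x_dq))))
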
47 changes: 20 additions & 27 deletions python/mlc_llm/quantization/per_tensor_quantization.py
@@ -221,16 +221,14 @@ def quantize_float8(  # pylint: disable=too-many-locals
     ) -> Union[Tuple[nn.Tensor], Tuple[nn.Tensor, nn.Tensor]]:
         """Per-tensor quantization for weight tensor, defined in tensor expression."""

-        # quantize_dtype = DataType(quantize_dtype)
-
         if self.use_scale:
             # min_scaling_factor taken from TRT-LLM
             def _compute_scale(x: te.Tensor) -> te.Tensor:
                 max_abs = topi.max(topi.abs(x))
                 min_scaling_factor = tir.const(1.0 / (self.max_int_value * 512.0), self.model_dtype)
                 scale = topi.maximum(
                     max_abs.astype(self.model_dtype) / self.max_int_value, min_scaling_factor
-                )
+                ).astype("float32")
                 scale = topi.expand_dims(scale, axis=0)
                 return scale

@@ -315,7 +313,7 @@ def dequantize_float8(
         else:
             dequantized_tensor = q_tensor.astype(self.model_dtype)
             if scale is not None:
-                dequantized_tensor = dequantized_tensor * scale
+                dequantized_tensor = dequantized_tensor * scale.astype(dequantized_tensor.dtype)
         return dequantized_tensor


@@ -343,9 +341,9 @@ def __init__(  # pylint: disable=too-many-arguments
         )
         self.q_calibration_scale = None
         if config.use_scale:
-            self.q_scale = nn.Parameter((1,), config.model_dtype)
+            self.q_scale = nn.Parameter((1,), "float32")
             if config.calibration_mode == "inference":
-                self.q_calibration_scale = nn.Parameter((1,), config.model_dtype)
+                self.q_calibration_scale = nn.Parameter((1,), "float32")
         else:
             self.q_scale = None
         if bias:
@@ -412,7 +410,7 @@ def forward(self, x: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name
         # Note: Use calibration scale when calibration is enabled
         if self.config.calibration_mode == "inference":
             if self.q_calibration_scale:
-                x /= self.q_calibration_scale
+                x /= self.q_calibration_scale.astype(x.dtype)
             x_q = x.astype(self.config.activation_dtype)
             x_scale = self.q_calibration_scale
         elif self.config.calibration_mode == "max":
@@ -428,25 +426,21 @@ def forward(self, x: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name
                 [f"{self.name}.q_calibration_scale", "max", x_scale],
                 out=nn.Tensor.placeholder(x_scale.shape, x_scale.dtype),
             )
-            x_q = (x / x_scale).astype(self.config.activation_dtype)
+            x_q = (x / x_scale.astype(x.dtype)).astype(self.config.activation_dtype)
+            x = x_q.astype(self.config.model_dtype) * x_scale.astype(self.config.model_dtype)
         else:
             raise ValueError(f"Unknown calibration mode: {self.config.calibration_mode}")

-        if self.config.weight_dtype == self.config.storage_dtype and not self.config.use_scale:
-            w = self.q_weight
-            w = nn.op.permute_dims(w)
-            x = nn.op.matmul(
-                x_q, w, out_dtype=self.out_dtype
-            )  # mixed precision matmul: fp8 * fp8 => fp16
+        if (
+            self.config.weight_dtype == self.config.storage_dtype
+            and self.config.calibration_mode == "inference"
+        ):
+            x = nn.op.matmul(x_q, nn.permute_dims(self.q_weight), out_dtype="float32")
+            if self.config.use_scale:
+                scale = x_scale * self.q_scale
+                x = x * scale
+            x = x.astype(self.out_dtype)
         else:
-            # dequantize input and weight to fp16, this can be fused into matmul during lowering
-            x = nn.op.tensor_expr_op(
-                lambda quantized_x, scale: self.config._dequantize(  # pylint: disable=protected-access
-                    quantized_x, scale, out_shape=x.shape
-                ),
-                "dequantize_x",
-                args=[x_q, x_scale],
-            )
             w = nn.op.tensor_expr_op(
                 lambda weight, scale: self.config._dequantize(  # pylint: disable=protected-access
                     weight,
@@ -463,8 +457,7 @@ def forward(self, x: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name
                 "dequantize",
                 args=[self.q_weight, self.q_scale],
             )
-            w = nn.op.permute_dims(w)
-            x = nn.op.matmul(x, w, out_dtype=self.out_dtype)
+            x = nn.op.matmul(x, nn.permute_dims(w), out_dtype=self.out_dtype)
         if self.bias is not None:
             x = x + self.bias
         return x
@@ -494,7 +487,7 @@ def __init__(self, num: Union[int, tir.Var], dim: int, config: PerTensorQuantize
             (num, tir.ceildiv(dim, config.num_elem_per_storage)), config.storage_dtype
         )
         if self.config.use_scale:
-            self.q_scale = nn.Parameter((1,), config.model_dtype)
+            self.q_scale = nn.Parameter((1,), "float32")
         else:
             self.q_scale = None

@@ -612,9 +605,9 @@ def __init__(
         )
         self.q_calibration_scale = None
         if config.use_scale:
-            self.q_scale = nn.Parameter((1,), config.model_dtype)
+            self.q_scale = nn.Parameter((1,), "float32")
             if config.calibration_mode == "inference":
-                self.q_calibration_scale = nn.Parameter((1,), config.model_dtype)
+                self.q_calibration_scale = nn.Parameter((1,), "float32")
         else:
             self.q_scale = None

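
Taken together, the reworked inference path runs the quantized matmul with float32 accumulation and applies the combined activation and weight scales once, in float32, before casting to the output dtype. Below is a hedged NumPy sketch of that flow; float16 stands in for the FP8 dtypes, the 448 / 512 constants mirror the min_scaling_factor logic above, and every helper name is illustrative rather than the library's API:

import numpy as np

out_dtype = np.float16
max_int_value = 448.0                        # assumed e4m3 finite maximum
min_scaling_factor = 1.0 / (max_int_value * 512.0)

def per_tensor_scale(t: np.ndarray) -> np.ndarray:
    # max-abs scale with a TRT-LLM-style lower bound, kept in float32
    scale = max(float(np.abs(t).max()) / max_int_value, min_scaling_factor)
    return np.array([scale], dtype=np.float32)

x = np.random.randn(4, 64).astype(np.float16)
w = np.random.randn(32, 64).astype(np.float16)           # (out_features, in_features)

x_scale = per_tensor_scale(x)
w_scale = per_tensor_scale(w)
x_q = (x / x_scale.astype(x.dtype)).astype(np.float16)   # stand-in for FP8 activations
w_q = (w / w_scale.astype(w.dtype)).astype(np.float16)   # stand-in for FP8 weights

# "FP8 x FP8 => float32" matmul, then a single rescale in float32.
acc = x_q.astype(np.float32) @ w_q.T.astype(np.float32)
y = (acc * (x_scale * w_scale)).astype(out_dtype)

ref = x.astype(np.float32) @ w.T.astype(np.float32)
print("max error vs float32 reference:", float(np.max(np.abs(y.astype(np.float32) - ref))))
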