diff --git a/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp b/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp index 9fa36d16b8e..f2255604e21 100644 --- a/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp +++ b/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp @@ -44,51 +44,107 @@ namespace torch_ext { -W4A16GemmRunner::W4A16GemmRunner(at::ScalarType activationDtype, int64_t quant_mode) +finegrainedMixedDtypeGemmRunner::finegrainedMixedDtypeGemmRunner( + at::ScalarType activationDtype, at::ScalarType outputDtype, int64_t quant_mode) : mActivationDtype(activationDtype) + , mOutputDtype(outputDtype) { if (quant_mode == 0) { if (activationDtype == at::ScalarType::Half) { + TORCH_CHECK( + outputDtype == activationDtype, "Activation dtype needs to match output dtype: ", activationDtype); mGemmRunner = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, half, half>>(); } else if (activationDtype == at::ScalarType::BFloat16) { + TORCH_CHECK( + outputDtype == activationDtype, "Activation dtype needs to match output dtype: ", activationDtype); mGemmRunner = std::make_shared< tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>(); } + + else if (activationDtype == at::ScalarType::Float8_e4m3fn) + { + if (outputDtype == at::ScalarType::BFloat16) + { + mGemmRunner = std::make_shared< + tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, __nv_bfloat16, __nv_bfloat16>>(); + } + else if (outputDtype == at::ScalarType::Half) + { + mGemmRunner = std::make_shared< + tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, half, half>>(); + } + else + { + TORCH_CHECK(false, "Unsupported output dtype for Float8_e4m3fn activation: ", outputDtype); + } + } + else + { + TORCH_CHECK(false, "Unsupported activation dtype: ", activationDtype); + } } + else if (quant_mode == 1) { if (activationDtype == at::ScalarType::Half) { + TORCH_CHECK( + outputDtype == activationDtype, "Activation dtype needs to match output dtype: ", activationDtype); mGemmRunner = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>>(); } else if (activationDtype == at::ScalarType::BFloat16) { + TORCH_CHECK( + outputDtype == activationDtype, "Activation dtype needs to match output dtype: ", activationDtype); mGemmRunner = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>(); } + else if (activationDtype == at::ScalarType::Float8_e4m3fn) + { + if (outputDtype == at::ScalarType::BFloat16) + { + mGemmRunner = std::make_shared< + tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, __nv_bfloat16, __nv_bfloat16>>(); + } + else if (outputDtype == at::ScalarType::Half) + { + mGemmRunner = std::make_shared< + tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>>(); + } + else + { + TORCH_CHECK(false, "Unsupported output dtype for Float8_e4m3fn activation: ", outputDtype); + } + } } else { - TORCH_CHECK(false, "Unsupported quant mode for W4A16GemmRunner: ", quant_mode); + TORCH_CHECK(false, "Unsupported quant mode for finegrainedMixedDtypeGemmRunner: ", quant_mode); } - TORCH_CHECK(mGemmRunner, "Failed to create W4A16 GEMM runner for activation type ", c10::toString(activationDtype)); + TORCH_CHECK(mGemmRunner, "Failed to create finegrained mixed dtype GEMM runner for activation type ", + c10::toString(activationDtype)); mConfigs 
= mGemmRunner->getConfigs(); // Get configs via the interface - TORCH_CHECK(!mConfigs.empty(), "Failed to get CUTLASS configs for W4A16 GEMM with activation type ", + TORCH_CHECK(!mConfigs.empty(), "Failed to get CUTLASS configs for finegrainedMixedDtype GEMM with activation type ", c10::toString(activationDtype)); } -at::Tensor W4A16GemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_packed, at::Tensor const& scales, - int64_t group_size_long, int64_t configIdx, std::optional<at::Tensor> bias, std::optional<at::Tensor> zeros) const +at::Tensor finegrainedMixedDtypeGemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_packed, + at::Tensor const& scales, int64_t group_size_long, int64_t configIdx, std::optional<at::Tensor> bias, + std::optional<at::Tensor> zeros, double alpha) const { TORCH_CHECK(A.is_cuda() && B_packed.is_cuda() && scales.is_cuda(), "All input tensors must be on CUDA"); TORCH_CHECK(A.scalar_type() == mActivationDtype, "Activation tensor A's dtype ", c10::toString(A.scalar_type()), @@ -96,6 +152,7 @@ at::Tensor W4A16GemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_pac TORCH_CHECK(B_packed.scalar_type() == torch::kQUInt4x2 || B_packed.scalar_type() == torch::kInt8 || B_packed.scalar_type() == torch::kUInt8, "B_packed must be quint4x2, int8, or uint8 (view of quantized data)"); + TORCH_CHECK(A.is_contiguous() && B_packed.is_contiguous() && scales.is_contiguous(), "All input tensors (A, B_packed, scales) must be contiguous"); @@ -156,19 +213,18 @@ at::Tensor W4A16GemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_pac output_shape_vec.push_back(N_orig); } - // Set output dtype based on activation dtype torch::ScalarType output_dtype; - if (mActivationDtype == at::ScalarType::Half) + if (mOutputDtype == at::ScalarType::Half) { output_dtype = torch::kFloat16; } - else if (mActivationDtype == at::ScalarType::BFloat16) + else if (mOutputDtype == at::ScalarType::BFloat16) { output_dtype = torch::kBFloat16; } else { - TORCH_CHECK(false, "Unsupported activation type for output dtype determination"); + TORCH_CHECK(false, "Unsupported output dtype"); } torch::Tensor C_tensor = torch::empty(output_shape_vec, A.options().dtype(output_dtype)); @@ -201,16 +257,15 @@ at::Tensor W4A16GemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_pac cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index()); - mGemmRunner->gemm(A_ptr, B_ptr, scales_ptr, zeros_ptr, bias_ptr, - 1.0f, // alpha - C_ptr, M, N_orig, K, group_size, gemm_config_to_use, workspace_ptr, workspace_bytes, stream); + mGemmRunner->gemm(A_ptr, B_ptr, scales_ptr, zeros_ptr, bias_ptr, static_cast<float>(alpha), C_ptr, M, N_orig, K, + group_size, gemm_config_to_use, workspace_ptr, workspace_bytes, stream); return C_tensor; } -int64_t W4A16GemmRunner::getNumConfigs() const +int64_t finegrainedMixedDtypeGemmRunner::getNumConfigs() const { - TORCH_CHECK(mGemmRunner, "W4A16GemmRunner not initialized properly."); + TORCH_CHECK(mGemmRunner, "finegrainedMixedDtypeGemmRunner not initialized properly."); return static_cast<int64_t>(mConfigs.size()); } @@ -218,8 +273,8 @@ int64_t W4A16GemmRunner::getNumConfigs() const TORCH_LIBRARY_FRAGMENT(trtllm, m) { - m.class_<torch_ext::W4A16GemmRunner>("W4A16GemmRunner") - .def(torch::init<at::ScalarType, int64_t>()) - .def("run_gemm", &torch_ext::W4A16GemmRunner::runGemm) - .def("get_num_configs", &torch_ext::W4A16GemmRunner::getNumConfigs); + m.class_<torch_ext::finegrainedMixedDtypeGemmRunner>("finegrainedMixedDtypeGemmRunner") + .def(torch::init<at::ScalarType, at::ScalarType, int64_t>()) + .def("run_gemm", &torch_ext::finegrainedMixedDtypeGemmRunner::runGemm) + .def("get_num_configs", 
&torch_ext::finegrainedMixedDtypeGemmRunner::getNumConfigs); } diff --git a/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h b/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h index 1b2083de5a0..5bda7be3eb6 100644 --- a/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h +++ b/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h @@ -24,14 +24,15 @@ namespace torch_ext { -class W4A16GemmRunner : public torch::CustomClassHolder +class finegrainedMixedDtypeGemmRunner : public torch::CustomClassHolder { public: - explicit W4A16GemmRunner(at::ScalarType activationDtype, int64_t quant_mode = 0); + explicit finegrainedMixedDtypeGemmRunner( + at::ScalarType activationDtype, at::ScalarType outputDtype, int64_t quant_mode = 0); at::Tensor runGemm(at::Tensor const& A, at::Tensor const& B_packed, at::Tensor const& scales, int64_t group_size_long, int64_t configIdx = -1, std::optional<at::Tensor> bias = std::nullopt, - std::optional<at::Tensor> zeros = std::nullopt) const; + std::optional<at::Tensor> zeros = std::nullopt, double alpha = 1.0) const; int64_t getNumConfigs() const; @@ -39,6 +40,7 @@ class W4A16GemmRunner : public torch::CustomClassHolder std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface> mGemmRunner; std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> mConfigs; at::ScalarType mActivationDtype; + at::ScalarType mOutputDtype; }; } // namespace torch_ext diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index ffeb90c2fd3..d2320feaa1b 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -675,24 +675,27 @@ def _( dtype=output_dtype) -class W4A16GemmRunner(TunableRunner): +class FinegrainedMixedDtypeGemm(TunableRunner): _runner_dict = dict() MAX_SUPPORTED_SM_VERSION = 90 - def __init__(self, activation_dtype: torch.dtype, quant_mode: int): - instance_key = (activation_dtype, quant_mode) - if instance_key not in W4A16GemmRunner._runner_dict: - W4A16GemmRunner._runner_dict[ - instance_key] = torch.classes.trtllm.W4A16GemmRunner( - activation_dtype, quant_mode) - self._w4a16_gemm_runner = W4A16GemmRunner._runner_dict[instance_key] + def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype, + quant_mode: int): + instance_key = (activation_dtype, output_dtype, quant_mode) + if instance_key not in FinegrainedMixedDtypeGemm._runner_dict: + FinegrainedMixedDtypeGemm._runner_dict[ + instance_key] = torch.classes.trtllm.finegrainedMixedDtypeGemmRunner( + activation_dtype, output_dtype, quant_mode) + self._finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm._runner_dict[ + instance_key] def get_valid_tactics( self, inputs: List[torch.Tensor], profile: OptimizationProfile, ) -> List[int]: - return list(range(self._w4a16_gemm_runner.get_num_configs())) + return list( + range(self._finegrained_mixed_dtype_gemm_runner.get_num_configs())) def forward(self, inputs: List[torch.Tensor], @@ -707,25 +710,25 @@ def forward(self, activation, weights_packed, scales = inputs - return self._w4a16_gemm_runner.run_gemm( - activation, - weights_packed, - scales, - kwargs["group_size"], - tactic, - kwargs["bias"], - kwargs["zeros"], - ) + alpha = 1.0 if kwargs.get("alpha") is None else kwargs["alpha"] + return self._finegrained_mixed_dtype_gemm_runner.run_gemm( - activation, weights_packed, scales, kwargs["group_size"], tactic, + activation, weights_packed, scales, kwargs["group_size"], tactic, + kwargs["bias"], kwargs["zeros"], alpha) -@torch.library.custom_op("trtllm::w4a16_gemm", mutates_args=()) -def w4a16_gemm(input: torch.Tensor, - weight: torch.Tensor, - scales: torch.Tensor, - group_size: int, - 
has_zero_point: bool, - bias: Optional[torch.Tensor] = None, - zeros: Optional[torch.Tensor] = None) -> torch.Tensor: + +@torch.library.custom_op("trtllm::finegrained_mixed_dtype_gemm", + mutates_args=()) +def finegrained_mixed_dtype_gemm( + input: torch.Tensor, + weight: torch.Tensor, + scales: torch.Tensor, + group_size: int, + has_zero_point: bool, + output_dtype: torch.dtype, + alpha: Optional[float] = None, + bias: Optional[torch.Tensor] = None, + zeros: Optional[torch.Tensor] = None) -> torch.Tensor: assert not has_zero_point or zeros is not None, "Expected 'zeros' tensor when has_zero_point is True" @@ -741,16 +744,24 @@ def w4a16_gemm(input: torch.Tensor, if quant_mode == 0: assert zeros is None, "When quant_mode is 0 (FINEGRAINED_SCALE_ONLY), zeros must be None" - w4a16_gemm_runner = W4A16GemmRunner(input.dtype, quant_mode) + finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm( + input.dtype, output_dtype, quant_mode) + + kwargs = { + "group_size": group_size, + "zeros": zeros, + "bias": bias, + "alpha": alpha + } - kwargs = {"group_size": group_size, "zeros": zeros, "bias": bias} - _, best_tactic = tuner.choose_one("trtllm::w4a16_gemm::gemm", - [w4a16_gemm_runner], tuning_config, - [input, weight, scales], **kwargs) + _, best_tactic = tuner.choose_one( + "trtllm::finegrained_mixed_dtype_gemm::gemm", + [finegrained_mixed_dtype_gemm_runner], tuning_config, + [input, weight, scales], **kwargs) - return w4a16_gemm_runner(inputs=[input, weight, scales], - tactic=best_tactic, - **kwargs) + return finegrained_mixed_dtype_gemm_runner(inputs=[input, weight, scales], + tactic=best_tactic, + **kwargs) @torch.library.custom_op("trtllm::attention", mutates_args=()) diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 134f1c8ebf8..3db075da4b2 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -47,6 +47,12 @@ class TensorParallelMode(str, enum.Enum): def split_dim(cls, mode): return 1 if mode == cls.ROW else 0 + # Helper to shard the corresponding per-channel activation scales + # Which shard along the dimension orthogonal to the weights + @classmethod + def flip(cls, mode): + return cls.ROW if mode == cls.COLUMN else cls.COLUMN + def load_weight_shard( weight, @@ -110,12 +116,13 @@ def load_weights_vanilla_helper(module: Linear, weights: List[Dict]): weight = load_weight_shard(weights[0]['weight'], module.tp_size, module.tp_rank, module.tp_mode, device) - if module.has_w4a16_awq: + if module.has_w4a16_awq or module.has_w4a8_awq: # NOTE: without the preprocess during the runtime, the gemm output nan's. in order to use the preprocess_weights_for_mixed_gemm # we need to cast the weight to int8 first. + activation_dtype = torch.float16 if module.has_w4a16_awq else torch.float8_e4m3fn weight = preprocess_weights_for_mixed_gemm( weight.T.to(torch.int8).contiguous().cpu(), torch.quint4x2, - torch.float16).cuda().contiguous() + activation_dtype).cuda().contiguous() copy_weight(module.weight, weight) @@ -894,7 +901,7 @@ def create_weights(self, module: Linear, in_features: int, f"for INT4 per-group quantization scale dimensions.") module.weight_scale = Parameter(torch.empty( - (out_features, in_features // group_size), dtype=dtype), + (in_features // group_size, out_features), dtype=dtype), requires_grad=False) # NOTE: Not in all linear we have this tensor - pre_quant_scale is computed as an average and merged with the # LayerNorm for QKV and Gate/Up projection layers when possible. 
we can see the tensor only for o_proj and down_proj @@ -910,19 +917,19 @@ def apply(self, module: Linear, input: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: if module.pre_quant_scale is not None: - pre_quant_scale = module.pre_quant_scale.repeat(input.shape[0], 1) - input = torch.mul(input, pre_quant_scale) + input = input * module.pre_quant_scale bias = bias.contiguous() if bias is not None else None - output = torch.ops.trtllm.w4a16_gemm(input.to( - module.dtype).contiguous(), - module.weight, - module.weight_scale.T.contiguous(), - module.quant_config.group_size, - module.quant_config.has_zero_point, - bias, - zeros=None) + output = torch.ops.trtllm.finegrained_mixed_dtype_gemm( + input=input.to(module.dtype).contiguous(), + weight=module.weight, + scales=module.weight_scale, + group_size=module.quant_config.group_size, + has_zero_point=module.quant_config.has_zero_point, + output_dtype=module.dtype or input.dtype, + bias=bias, + zeros=None) return output def load_weight_scales( @@ -955,9 +962,16 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: load_weights_vanilla_helper(module, weights) device = torch.device('cuda') - pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'], - module.tp_size, module.tp_rank, - module.tp_mode, device) + + pre_quant_scale = load_weight_shard( + weights[0]["pre_quant_scale"], + module.tp_size, + module.tp_rank, + # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around + TensorParallelMode.flip(module.tp_mode), + device, + ) + module.pre_quant_scale = Parameter( torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype), requires_grad=False).to(device=device) @@ -967,7 +981,7 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: module.tp_mode, device) copy_weight(module.pre_quant_scale, pre_quant_scale) - copy_weight(module.weight_scale, weight_scale) + copy_weight(module.weight_scale, weight_scale.T.contiguous()) def load_weights_fused_qkv_linear(self, module: Linear, weights: List[Dict]) -> None: @@ -984,7 +998,7 @@ def load_weights_fused_qkv_linear(self, module: Linear, weight_scales = self.load_weight_scales(weights) # Create concatenated weight scale tensor - cat_weight_scale = torch.cat(weight_scales, dim=0) + cat_weight_scale = torch.cat(weight_scales, dim=0).T.contiguous() copy_weight(module.weight_scale, cat_weight_scale) def load_weights_fused_gate_up_linear(self, module: Linear, @@ -1006,10 +1020,250 @@ def load_weights_fused_gate_up_linear(self, module: Linear, right_scale = load_weight_shard(weights[1]['weight_scale'], module.tp_size, module.tp_rank, module.tp_mode, device).contiguous() - fused_scale = torch.cat([left_scale, right_scale], dim=0) + fused_scale = torch.cat([left_scale, right_scale], dim=0).T.contiguous() copy_weight(module.weight_scale, fused_scale) +class W4A8_AWQ_LinearMethod(LinearMethodBase): + + def create_weights(self, module: Linear, in_features: int, + out_features: int, bias: bool, dtype: torch.dtype): + # Quantized weights + module.weight = Parameter(torch.empty( + (in_features, out_features // 2), + dtype=torch.int8, + ), + requires_grad=False) + + group_size = module.quant_config.group_size + if in_features % group_size != 0: + raise ValueError( + f"in_features ({module.in_features}) must be divisible by group_size ({group_size}) " + f"for INT4 per-group quantization scale dimensions.") + + # NOTE: for FP8 activation, scales needs to be float16 + module.weight_scale = 
Parameter(torch.empty( + (in_features // group_size, out_features), dtype=torch.float16), + requires_grad=False) + + # Similar to W4A16 AWQ, not all linears will have this tensor + module.pre_quant_scale = None + + module.input_scale = Parameter(torch.tensor(1., dtype=torch.float32), + requires_grad=False) + module.inv_input_scale = Parameter(torch.tensor(1., + dtype=torch.float32), + requires_grad=False) + + module.alpha = Parameter(torch.empty([1], dtype=torch.float32), + requires_grad=False) + + if bias: + module.bias = Parameter(torch.empty((out_features), dtype=dtype), + requires_grad=False) + else: + module.register_parameter("bias", None) + + def apply(self, module: Linear, input: torch.Tensor, + bias: Optional[torch.Tensor]): + """ + modelopt flow for w4a8_awq: + 1. multiply the input by pre_quant_scale + 2. quantize the input to fp8 using input_scale + 3. unpack the weights and multiply by weight_scale (int4 -> fp16) + 4. divide by weight_scale_2 (fp16 -> fp8 so the gemm can run in fp8) + 5. run the gemm in fp8 + 6. rescale the output using alpha, which is input_scale * weight_scale_2 + """ + if module.pre_quant_scale is not None: + input = input * module.pre_quant_scale + + if input.dtype == torch.float8_e4m3fn: + quantized_input = input + else: + quantized_input, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + input, (module.input_scale)) + + bias = bias.contiguous() if bias is not None else None + + output = torch.ops.trtllm.finegrained_mixed_dtype_gemm( + input=quantized_input.contiguous(), + weight=module.weight, + scales=module.weight_scale, + group_size=module.quant_config.group_size, + has_zero_point=module.quant_config.has_zero_point, + output_dtype=module.dtype + or input.dtype, # NOTE: output_dtype can only be bf16/fp16 for W4A8 + alpha=module.alpha.item(), + bias=bias, + zeros=None) + + return output + + def load_weight_scales_w4a8(self, + weights: List[Dict], + tp_size: int = 1, + tp_rank: int = 0, + tp_mode: Optional[TensorParallelMode] = None): + # For concatenated weights (qkv_proj / up_gate_proj), the global scaling factors and input scaling factors should be shared. + input_scale = None + weight_scale_2 = None + weight_scale = [] + + device = torch.device("cuda") + + for w in weights: + if "input_scale" in w: + if input_scale is None: + input_scale = w["input_scale"][...] + else: + assert input_scale == w["input_scale"][ + ...], "The input_scale should be the same for all the weights" + if "weight_scale" in w: + ws = load_weight_shard(w["weight_scale"], + tp_size, + tp_rank, + tp_mode, + device=device) + + weight_scale.append(ws.to(torch.float16)) + if "weight_scale_2" in w: + if weight_scale_2 is None: + weight_scale_2 = w["weight_scale_2"][...] 
+ else: + assert weight_scale_2 == w["weight_scale_2"][ + ...], "The weight_scale_2 should be same for all the weights" + + # Compute scaling factor and alpha required by GEMM kernels (rescale the gemm output in fp8) + alpha = (input_scale.float() * weight_scale_2.float()) + + return input_scale, weight_scale, alpha, weight_scale_2 + + def load_weights_vanilla(self, module: Linear, weights: List[Dict]): + load_weights_vanilla_helper(module, weights) + + device = torch.device('cuda') + pre_quant_scale = load_weight_shard( + weights[0]["pre_quant_scale"], + module.tp_size, + module.tp_rank, + # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around + TensorParallelMode.flip(module.tp_mode), + device, + ) + + assert pre_quant_scale.dtype == module.dtype + + module.pre_quant_scale = Parameter( + torch.empty((module.in_features, ), dtype=pre_quant_scale.dtype), + requires_grad=False).to(device=device) + + copy_weight(module.pre_quant_scale, pre_quant_scale) + + input_scale, weight_scale, alpha, weight_scale_2 = self.load_weight_scales_w4a8( + weights=weights, + tp_size=module.tp_size, + tp_rank=module.tp_rank, + tp_mode=module.tp_mode) + + assert len(weight_scale) == 1, "there should be only one weight scale" + + weight_scale = (weight_scale[0].T / weight_scale_2).contiguous() + + copy_weight(module.weight_scale, weight_scale) + copy_weight(module.input_scale, input_scale) + copy_weight(module.alpha, alpha) + + module.inv_input_scale.data = 1.0 / module.input_scale + + def load_weights_fused_qkv_linear(self, module: Linear, + weights: List[Dict]): + + q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( + module, weights) + + fused_weight = torch.cat((q_weight, k_weight, v_weight)) + fused_weight = preprocess_weights_for_mixed_gemm( + fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2, + torch.float8_e4m3fn).cuda().contiguous() + + copy_weight(module.weight, fused_weight) + + input_scale, weight_scales, alpha, weight_scale_2 = self.load_weight_scales_w4a8( + weights=weights, + tp_size=module.tp_size, + tp_rank=module.tp_rank, + tp_mode=module.tp_mode) + + # Create concatenated weight scale tensor + cat_weight_scale = (torch.cat(weight_scales, dim=0).T / + weight_scale_2).contiguous() + copy_weight(module.weight_scale, cat_weight_scale) + copy_weight(module.input_scale, input_scale) + copy_weight(module.alpha, alpha) + + # NOTE: pre_quant_scale is the same for q,k,v since modelopt checks which layer shared the same input and create an avg pre_quant_scale + # Usually when modelopt exports the quantized model, pre_quant_Scale is fused in the layer norm (this case relevant if fused is disabled - modelopt internal) + if "pre_quant_scale" in weights[0].keys(): + + pre_quant_scale = load_weight_shard( + weights[0]["pre_quant_scale"], + module.tp_size, + module.tp_rank, + # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around + TensorParallelMode.flip(module.tp_mode), + torch.device('cuda'), + ) + + module.pre_quant_scale = Parameter( + torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype), + requires_grad=False).to(device=torch.device('cuda')) + + copy_weight(module.pre_quant_scale, pre_quant_scale) + + def load_weights_fused_gate_up_linear(self, module: Linear, + weights: List[Dict]): + + gate_weight, up_weight = load_weights_fused_gate_up_helper( + module, weights) + + fused_weight = torch.cat((gate_weight, up_weight)) + fused_weight = preprocess_weights_for_mixed_gemm( + 
fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2, + torch.float8_e4m3fn).cuda().contiguous() + + copy_weight(module.weight, fused_weight) + + input_scale, weight_scale, alpha, weight_scale_2 = self.load_weight_scales_w4a8( + weights=weights, + tp_size=module.tp_size, + tp_rank=module.tp_rank, + tp_mode=module.tp_mode) + + fused_scale = (torch.cat(weight_scale, dim=0).T / + weight_scale_2).contiguous() + copy_weight(module.weight_scale, fused_scale) + copy_weight(module.input_scale, input_scale) + copy_weight(module.alpha, alpha) + + if "pre_quant_scale" in weights[0].keys(): + pre_quant_scale = load_weight_shard( + weights[0]["pre_quant_scale"], + module.tp_size, + module.tp_rank, + # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around + TensorParallelMode.flip(module.tp_mode), + torch.device('cuda'), + ) + + # NOTE:Create this tensor in load_weights, since not all layer have this tensor and memory is not allocated for it (same as W4A16) + module.pre_quant_scale = Parameter( + torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype), + requires_grad=False).to(device=torch.device('cuda')) + + copy_weight(module.pre_quant_scale, pre_quant_scale) + + def get_quant_method(quant_config: Optional[QuantConfig] = None): if quant_config is None or not quant_config.layer_quant_mode.has_any_quant( exclude_kv_cache=True): @@ -1027,6 +1281,9 @@ def get_quant_method(quant_config: Optional[QuantConfig] = None): if quant_config.layer_quant_mode.is_int4_weight_only_per_group( ) and quant_config.quant_algo == QuantAlgo.W4A16_AWQ: return W4A16_AWQ_LinearMethod() + if quant_config.layer_quant_mode.is_int4_weight_only_per_group( + ) and quant_config.quant_algo == QuantAlgo.W4A8_AWQ: + return W4A8_AWQ_LinearMethod() raise ValueError(f'unsupported quant mode: {quant_config.quant_mode}') @@ -1151,6 +1408,12 @@ def has_w4a16_awq(self): return self.quant_config is not None and self.quant_config.layer_quant_mode.is_int4_weight_only_per_group( ) and self.quant_config.quant_algo == QuantAlgo.W4A16_AWQ + @property + def has_w4a8_awq(self): + assert self._weights_created + return self.quant_config is not None and self.quant_config.layer_quant_mode.is_int4_weight_only_per_group( + ) and self.quant_config.quant_algo == QuantAlgo.W4A8_AWQ + def apply_linear(self, input, bias, diff --git a/tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py b/tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py new file mode 100644 index 00000000000..0041f11da6b --- /dev/null +++ b/tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py @@ -0,0 +1,122 @@ +import pytest +import torch +from utils.util import woq_assert_near_eq, woq_groupwise_gt_matmul + +import tensorrt_llm +from tensorrt_llm._torch.custom_ops.torch_custom_ops import \ + FinegrainedMixedDtypeGemm +from tensorrt_llm._utils import get_sm_version + + +@pytest.mark.parametrize( + "m, n, k, group_size, activation_dtype, has_pre_quant, has_zero, has_bias, use_w4a8_awq", + [ + (3, 1024, 64, 64, torch.bfloat16, True, False, True, False), + (128, 1024, 256, 64, torch.bfloat16, True, False, True, False), + (192, 2048, 384, 64, torch.bfloat16, True, False, True, False), + (256, 2048, 1024, 64, torch.bfloat16, True, False, True, False), + (4, 1024, 128, 128, torch.bfloat16, True, False, True, False), + (64, 1024, 256, 128, torch.bfloat16, True, False, True, False), + (384, 2048, 384, 128, torch.bfloat16, True, False, True, False), + (512, 2048, 1024, 128, torch.bfloat16, True, False, True, 
False), + (4, 1024, 128, 128, torch.bfloat16, True, True, True, False), + (64, 1024, 256, 128, torch.bfloat16, True, True, True, False), + (384, 2048, 384, 128, torch.bfloat16, True, True, True, False), + (512, 2048, 1024, 128, torch.bfloat16, True, True, False, False), + (3, 1024, 64, 64, torch.float16, True, False, True, False), + (128, 1024, 256, 64, torch.float16, True, False, True, False), + (192, 2048, 384, 64, torch.float16, True, False, True, False), + (256, 2048, 1024, 64, torch.float16, True, False, True, False), + (4, 1024, 128, 128, torch.float16, True, False, True, False), + (64, 1024, 256, 128, torch.float16, True, False, True, False), + (384, 2048, 384, 128, torch.float16, True, False, True, False), + (512, 2048, 1024, 128, torch.float16, True, False, True, False), + (4, 1024, 128, 128, torch.float16, True, True, True, False), + (64, 1024, 256, 128, torch.float16, True, True, True, False), + (384, 2048, 384, 128, torch.float16, True, True, True, False), + (512, 2048, 1024, 128, torch.float16, True, True, False, False), + (512, 2048, 1024, 128, torch.bfloat16, True, False, True, True), + (4, 1024, 128, 128, torch.bfloat16, True, True, True, True), + (64, 1024, 256, 128, torch.bfloat16, True, True, True, True), + (384, 2048, 384, 128, torch.bfloat16, True, True, True, True), + (512, 2048, 1024, 128, torch.bfloat16, True, True, False, True), + (128, 1024, 256, 128, torch.float16, True, False, True, True), + (192, 2048, 384, 128, torch.float16, True, False, True, True), + (256, 2048, 1024, 128, torch.float16, True, False, True, True), + (4, 1024, 128, 128, torch.float16, True, False, True, True), + ]) +def test_matmul_activation_int4_input(m, n, k, group_size, activation_dtype, + has_pre_quant, has_zero, has_bias, + use_w4a8_awq): + torch.manual_seed(0) + device = "cuda" + + if get_sm_version() > FinegrainedMixedDtypeGemm.MAX_SUPPORTED_SM_VERSION: + pytest.skip( + f"W4A16/W4A8 not supported for SM version {get_sm_version()}") + + total_groups = (k + group_size - 1) // group_size + scale_zero_dtype = torch.float16 if use_w4a8_awq else activation_dtype + activation = torch.randn(m, k, dtype=activation_dtype, device=device) + scale = torch.rand(total_groups, n, dtype=scale_zero_dtype, device=device) + zero = torch.randn(total_groups, n, dtype=scale_zero_dtype, + device=device) if has_zero else None + pre_quant_scale = torch.rand(1, k, dtype=activation_dtype, device=device) + bias = torch.randn(1, n, dtype=activation_dtype, + device=device) if has_bias else None + fp8_alpha = torch.rand(1, dtype=torch.float32, + device="cuda") if use_w4a8_awq else None + + num_weights_in_32_bits = 8 # for torch.quint4x2 + unprocessed_int_weight = torch.randint(-2**31, + 2**31, + (k, n // num_weights_in_32_bits), + dtype=torch.int32, + device=device) + unprocessed_weight = unprocessed_int_weight.view(torch.int8) + + if use_w4a8_awq: + activation_type = torch.float8_e4m3fn + else: + activation_type = activation_dtype + + # Ref quantized weights + unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 + ref_q_weight = unpacker(unprocessed_weight.cpu()).contiguous().cuda() + + cuda_q_weight = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm( + unprocessed_weight.cpu(), torch.quint4x2, + activation_type).cuda().contiguous() + + scale_ref = scale.repeat_interleave(group_size, dim=0)[:k, :] + ref_th_weight = ref_q_weight.to(activation_dtype) * scale_ref + + if has_zero: + zero_ref = zero.repeat_interleave(group_size, dim=0)[:k, :] + ref_th_weight += zero_ref + + if 
has_pre_quant: + pre_quant_scale = pre_quant_scale.repeat(m, 1) + activation = torch.mul(activation, pre_quant_scale) + + output = torch.ops.trtllm.finegrained_mixed_dtype_gemm( + input=activation.to(activation_type).contiguous() + if use_w4a8_awq else activation.contiguous(), + weight=cuda_q_weight, + scales=scale.contiguous(), + group_size=group_size, + has_zero_point=has_zero, + output_dtype= + activation_dtype, # NOTE: output_dtype needs to match activation dtype for W4A16. + # where in W4A8 output dtype is float16/bfloat16 where activation dtype is float8_e4m3fn + alpha=fp8_alpha.item() if use_w4a8_awq else None, + bias=bias.contiguous() if has_bias else None, + zeros=zero) + + if use_w4a8_awq: + activation *= fp8_alpha + + ref = woq_groupwise_gt_matmul(activation, + ref_th_weight.to(activation_dtype), bias) + + woq_assert_near_eq(ref, output, 2) diff --git a/tests/unittest/_torch/thop/test_w4a16_gemm.py b/tests/unittest/_torch/thop/test_w4a16_gemm.py deleted file mode 100644 index b3a034bd5d7..00000000000 --- a/tests/unittest/_torch/thop/test_w4a16_gemm.py +++ /dev/null @@ -1,94 +0,0 @@ -import pytest -import torch -from utils.util import woq_assert_near_eq, woq_groupwise_gt_matmul - -import tensorrt_llm -from tensorrt_llm._torch.custom_ops.torch_custom_ops import W4A16GemmRunner -from tensorrt_llm._utils import get_sm_version - - -@pytest.mark.parametrize( - "m, n, k, group_size, activation_dtype, has_pre_quant, has_zero, has_bias", - [ - (3, 1024, 64, 64, torch.bfloat16, True, False, True), - (128, 1024, 256, 64, torch.bfloat16, True, False, True), - (192, 2048, 384, 64, torch.bfloat16, True, False, True), - (256, 2048, 1024, 64, torch.bfloat16, True, False, True), - (4, 1024, 128, 128, torch.bfloat16, True, False, True), - (64, 1024, 256, 128, torch.bfloat16, True, False, True), - (384, 2048, 384, 128, torch.bfloat16, True, False, True), - (512, 2048, 1024, 128, torch.bfloat16, True, False, True), - (4, 1024, 128, 128, torch.bfloat16, True, True, True), - (64, 1024, 256, 128, torch.bfloat16, True, True, True), - (384, 2048, 384, 128, torch.bfloat16, True, True, True), - (512, 2048, 1024, 128, torch.bfloat16, True, True, False), - (3, 1024, 64, 64, torch.float16, True, False, True), - (128, 1024, 256, 64, torch.float16, True, False, True), - (192, 2048, 384, 64, torch.float16, True, False, True), - (256, 2048, 1024, 64, torch.float16, True, False, True), - (4, 1024, 128, 128, torch.float16, True, False, True), - (64, 1024, 256, 128, torch.float16, True, False, True), - (384, 2048, 384, 128, torch.float16, True, False, True), - (512, 2048, 1024, 128, torch.float16, True, False, True), - (4, 1024, 128, 128, torch.float16, True, True, True), - (64, 1024, 256, 128, torch.float16, True, True, True), - (384, 2048, 384, 128, torch.float16, True, True, True), - (512, 2048, 1024, 128, torch.float16, True, True, False), - ]) -def test_matmul_activation_int4_input(m, n, k, group_size, activation_dtype, - has_pre_quant, has_zero, has_bias): - torch.manual_seed(0) - device = "cuda" - - if get_sm_version() > W4A16GemmRunner.MAX_SUPPORTED_SM_VERSION: - pytest.skip(f"W4A16 not supported for SM version {get_sm_version()}") - - total_groups = (k + group_size - 1) // group_size - activation = torch.randn(m, k, dtype=activation_dtype, device=device) - scale = torch.rand(total_groups, n, dtype=activation_dtype, device=device) - zero = torch.randn(total_groups, n, dtype=activation_dtype, - device=device) if has_zero else None - pre_quant_scale = torch.rand(1, k, dtype=activation_dtype, device=device) - 
bias = torch.randn(1, n, dtype=activation_dtype, - device=device) if has_bias else None - - num_weights_in_32_bits = 8 # for torch.quint4x2 - unprocessed_int_weight = torch.randint(-2**31, - 2**31, - (k, n // num_weights_in_32_bits), - dtype=torch.int32, - device=device) - unprocessed_weight = unprocessed_int_weight.view(torch.int8) - - # Ref quantized weights - unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 - ref_q_weight = unpacker(unprocessed_weight.cpu()).contiguous().cuda() - - cuda_q_weight = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm( - unprocessed_weight.cpu(), torch.quint4x2, - activation_dtype).cuda().contiguous() - - scale_ref = scale.repeat_interleave(group_size, dim=0)[:k, :] - ref_th_weight = ref_q_weight.to(activation_dtype) * scale_ref - - if has_zero: - zero_ref = zero.repeat_interleave(group_size, dim=0)[:k, :] - ref_th_weight += zero_ref - - if has_pre_quant: - pre_quant_scale = pre_quant_scale.repeat(m, 1) - activation = torch.mul(activation, pre_quant_scale) - - output = torch.ops.trtllm.w4a16_gemm( - activation.contiguous(), - cuda_q_weight, - scale.contiguous(), - group_size, - has_zero, - bias.contiguous() if has_bias else None, - zeros=zero) - - ref = woq_groupwise_gt_matmul(activation, - ref_th_weight.to(activation_dtype), bias) - - woq_assert_near_eq(ref, output, 2) diff --git a/tests/unittest/_torch/thop/test_w4a16_linear.py b/tests/unittest/_torch/thop/test_w4a16_linear.py index 1398acc2971..8aac068211a 100644 --- a/tests/unittest/_torch/thop/test_w4a16_linear.py +++ b/tests/unittest/_torch/thop/test_w4a16_linear.py @@ -3,7 +3,8 @@ import tensorrt_llm.quantization.functional from tensorrt_llm._torch.autotuner import autotune -from tensorrt_llm._torch.custom_ops.torch_custom_ops import W4A16GemmRunner +from tensorrt_llm._torch.custom_ops.torch_custom_ops import \ + FinegrainedMixedDtypeGemm from tensorrt_llm._torch.modules.linear import Linear from tensorrt_llm._utils import get_sm_version from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig @@ -16,9 +17,10 @@ ) def test_w4a16_linear(dtype, weights_dtype, has_zero=False): - if get_sm_version() > W4A16GemmRunner.MAX_SUPPORTED_SM_VERSION: + if get_sm_version() > FinegrainedMixedDtypeGemm.MAX_SUPPORTED_SM_VERSION: pytest.skip( - f"W4A116 is not supported in this SM version {get_sm_version()}") + f"W4A16/W4A8 is not supported in this SM version {get_sm_version()}" + ) SEQ_LEN = 10 HIDDEN_SIZE = 128 @@ -72,12 +74,14 @@ def test_w4a16_linear(dtype, weights_dtype, has_zero=False): pre_quant_scale = pre_quant_scale.repeat(SEQ_LEN, 1) x = torch.mul(x, pre_quant_scale) - output_ref = torch.ops.trtllm.w4a16_gemm(x.contiguous(), - w, - weight_scale.type(x.dtype), - GROUP_SIZE, - has_zero, - bias, - zeros=None) + output_ref = torch.ops.trtllm.finegrained_mixed_dtype_gemm( + input=x.contiguous(), + weight=w, + scales=weight_scale.type(x.dtype), + group_size=GROUP_SIZE, + has_zero_point=has_zero, + bias=bias, + output_dtype=x.dtype, + zeros=None) torch.cuda.synchronize() torch.testing.assert_close(output, output_ref) diff --git a/tests/unittest/_torch/thop/test_w4a8_linear.py b/tests/unittest/_torch/thop/test_w4a8_linear.py new file mode 100644 index 00000000000..20187385a6d --- /dev/null +++ b/tests/unittest/_torch/thop/test_w4a8_linear.py @@ -0,0 +1,100 @@ +import pytest +import torch +from torch.nn.parameter import Parameter + +import tensorrt_llm.quantization.functional +from tensorrt_llm._torch.autotuner import autotune +from 
tensorrt_llm._torch.custom_ops.torch_custom_ops import \ + FinegrainedMixedDtypeGemm +from tensorrt_llm._torch.modules.linear import Linear +from tensorrt_llm._utils import get_sm_version +from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig + + +@pytest.mark.parametrize("weights_dtype", [torch.uint8]) +@pytest.mark.parametrize( + "dtype", + [torch.float16], +) +def test_w4a8_linear(dtype, weights_dtype, has_zero=False): + + if get_sm_version() > FinegrainedMixedDtypeGemm.MAX_SUPPORTED_SM_VERSION: + pytest.skip( + f"W4A16/W4A8 is not supported in this SM version {get_sm_version()}" + ) + + SEQ_LEN = 10 + HIDDEN_SIZE = 128 + OUTPUT_SIZE = 512 + GROUP_SIZE = 128 + torch.manual_seed(0) + + total_groups = (HIDDEN_SIZE + GROUP_SIZE - 1) // GROUP_SIZE + + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() + w = torch.randint(0, + 2**32 - 1, (HIDDEN_SIZE, OUTPUT_SIZE // 8), + dtype=torch.uint32, + device=x.device) + w = w.view(weights_dtype) + + pre_quant_scale = torch.rand(HIDDEN_SIZE, dtype=dtype).cuda() + weight_scale = torch.rand(total_groups, OUTPUT_SIZE, + dtype=torch.float16).cuda() + weight_scale_2 = torch.rand(1, dtype=torch.float32).cuda() + input_scale = Parameter(torch.tensor(1., dtype=torch.float32), + requires_grad=False).cuda() + bias = torch.randn(OUTPUT_SIZE, dtype=dtype).cuda().contiguous() + + qc = QuantConfig(quant_algo=QuantAlgo.W4A8_AWQ, + group_size=GROUP_SIZE, + has_zero_point=has_zero) + + linear_w4a8 = Linear(in_features=HIDDEN_SIZE, + out_features=OUTPUT_SIZE, + bias=True, + dtype=dtype, + quant_config=qc) + + linear_w4a8.load_weights([{ + 'pre_quant_scale': pre_quant_scale, + 'weight': w.T.clone(), + 'weight_scale': weight_scale.T, + 'bias': bias, + 'weight_scale_2': weight_scale_2, + 'input_scale': input_scale + }]) + + linear_w4a8 = linear_w4a8.cuda() + + preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm + w = preprocessor( + w.to(torch.int8).contiguous().cpu(), torch.quint4x2, + torch.float8_e4m3fn).cuda().contiguous() + + torch.testing.assert_close(linear_w4a8.weight, w) + + with torch.inference_mode(), autotune(): + output = linear_w4a8.forward(x) + + # ref linear + with torch.inference_mode(): + x = x * pre_quant_scale + + quantized_input, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + x, (input_scale)) + alpha = (weight_scale_2.float() * input_scale.float()).item() + + output_ref = torch.ops.trtllm.finegrained_mixed_dtype_gemm( + input=quantized_input.contiguous(), + weight=w.contiguous(), + scales=(weight_scale / weight_scale_2).to( + torch.float16).contiguous(), + group_size=GROUP_SIZE, + has_zero_point=has_zero, + output_dtype=x.dtype, + alpha=alpha, + bias=bias, + zeros=None) + torch.cuda.synchronize() + torch.testing.assert_close(output, output_ref)
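As a companion to the tests above, the scale bookkeeping that W4A8_AWQ_LinearMethod performs at load time can be summarized in a small standalone sketch. This is illustrative only: the helper name prepare_w4a8_scales and the toy shapes are not part of this change; the sketch simply mirrors how load_weights_vanilla folds weight_scale_2 into the per-group weight_scale and derives the GEMM alpha from input_scale * weight_scale_2.

import torch

def prepare_w4a8_scales(weight_scale: torch.Tensor,
                        weight_scale_2: torch.Tensor,
                        input_scale: torch.Tensor):
    """Illustrative helper (not part of this change) mirroring the W4A8 load path.

    weight_scale:   per-group scales from the checkpoint, shape (out_features, in_features // group_size)
    weight_scale_2: global scale that folds the fp16 scales into the fp8 range, shape (1,)
    input_scale:    static activation scale used for fp8 quantization, shape (1,)
    """
    # The kernel expects fp16 scales of shape (in_features // group_size, out_features)
    # with weight_scale_2 already divided out.
    kernel_scales = (weight_scale.T / weight_scale_2).to(torch.float16).contiguous()
    # The fp8 GEMM output is rescaled by alpha = input_scale * weight_scale_2.
    alpha = (input_scale.float() * weight_scale_2.float()).item()
    return kernel_scales, alpha

# Toy shapes: out_features = 512, in_features = 128, group_size = 128 -> one group.
ws = torch.rand(512, 1)
ws2 = torch.rand(1) + 0.5
a_scale = torch.tensor([0.5])
scales, alpha = prepare_w4a8_scales(ws, ws2, a_scale)
print(scales.shape, alpha)  # torch.Size([1, 512]) and a scalar alpha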