diff --git a/test/prototype/inductor/test_qsdpa_fusion.py b/test/prototype/inductor/test_qsdpa_fusion.py
index dc754d2682..eef7694049 100644
--- a/test/prototype/inductor/test_qsdpa_fusion.py
+++ b/test/prototype/inductor/test_qsdpa_fusion.py
@@ -18,26 +18,184 @@ from torchao.utils import torch_version_at_least
 
 
-class SelfAttnLikeModule(torch.nn.Module):
+def qdq(input, scale):
+    dtype = input.dtype
+    q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+        input,
+        torch.tensor([scale]),
+        torch.float8_e4m3fn,
+    )
+    dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+        q_input,
+        torch.tensor([scale]),
+        dtype,
+    )
+    return dq_input
+
+
+def fp8_convert_(model):
+    def generate_model_info(model):
+        from collections import namedtuple
+
+        mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"])
+        parent_child_mod_dict = {}
+
+        def create_mod_info_recursion(parent):
+            for name, mod in parent.named_children():
+                parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent)
+                create_mod_info_recursion(mod)
+
+        create_mod_info_recursion(model)
+        return parent_child_mod_dict
+
+    parent_child_mod_dict = generate_model_info(model)
+    for name, mod in model.named_modules():
+        mod_type_str = mod.__class__.__name__
+        if mod_type_str not in [
+            "Linear",
+            "SDPA",
+        ]:
+            continue
+        if mod_type_str == "Linear":
+            param = mod.weight
+            xmax = torch.max(param)
+            weight_scale = xmax / torch.finfo(torch.float8_e4m3fn).max
+            mod.weight_scale = weight_scale
+            q_param = torch.clamp(
+                (param / weight_scale),
+                torch.finfo(torch.float8_e4m3fn).min,
+                torch.finfo(torch.float8_e4m3fn).max,
+            ).to(torch.float8_e4m3fn)
+            mod.weight.data = q_param
+            patched_mod = FP8QDQLinear(mod.in_features, mod.out_features, False)
+            patched_mod.bias = mod.bias
+            patched_mod.weight_scale = weight_scale.item()
+            patched_mod.weight.data = q_param
+        else:
+            patched_mod = FP8QDQSDPA()
+            patched_mod.__dict__.update(mod.__dict__)
+            patched_mod.transpose_for_scores = mod.transpose_for_scores
+
+            patched_mod.q_out_scale = (
+                patched_mod.q_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.k_out_scale = (
+                patched_mod.k_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.attn_weights_scale = (
+                patched_mod.attn_weights_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.v_out_scale = (
+                patched_mod.v_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.qk_out_scale = (
+                patched_mod.qk_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.attn_out_scale = (
+                patched_mod.attn_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+
+        parent = parent_child_mod_dict[mod].parent
+        name = parent_child_mod_dict[mod].name
+        setattr(parent, name, patched_mod)
+    model.eval()
+    return model
+
+
+class FP8QDQLinear(torch.nn.Module):
+    def __init__(self, in_features, out_features, has_bias):
+        super().__init__()
+        self.qtype = torch.float8_e4m3fn
+        self.weight = torch.randn((out_features, in_features)).to(self.qtype)
+        self.weight_scale = 2.0
+        self.scale = 2.0
+        self.bias = None
+        if has_bias:
+            self.bias = torch.randn((out_features,))
+
+    def forward(self, input):
+        weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=self.weight.data,
+            scale=torch.tensor([self.weight_scale]),
+            output_dtype=torch.float,
+        )
+
+        q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+            tensor=input,
+            scale=torch.tensor([self.scale]),
+            float8_dtype=self.qtype,
+        )
+        dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=q_input,
+            scale=torch.tensor([self.scale]),
+            output_dtype=torch.float,
+        )
+
+        out = torch.nn.functional.linear(dq_input, weight, self.bias)
+        return out
+
+
+class FP8QDQSDPA(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.q_out_scale = 1.5
+        self.k_out_scale = 1.5
+        self.attn_weights_scale = 1.5
+        self.v_out_scale = 1.5
+        self.attn_out_scale = 1.5
+        self.qk_out_scale = 1.5
+
+    def forward(self, q, k, v, mask):
+        key = self.transpose_for_scores(q)
+        value = self.transpose_for_scores(k)
+        query = self.transpose_for_scores(v)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        query_qdq = qdq(query, self.q_out_scale)
+        key_qdq = qdq(key.transpose(-1, -2), self.k_out_scale)
+        attn_weights = torch.matmul(query_qdq, key_qdq) / (self.input_dim**0.5)
+
+        # Normalize the attention scores to probabilities.
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query.dtype)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        dropout = 0.0 if not self.training else self.dropout_prob
+        attn_weights = torch.nn.functional.dropout(
+            attn_weights, p=dropout, training=self.training
+        )
+
+        # Mask heads if we want to
+        if mask is not None:
+            attn_weights = attn_weights + mask
+
+        value_qdq = qdq(value, self.v_out_scale)
+        attn_weights_qdq = qdq(attn_weights, self.attn_weights_scale)
+        attn_output = torch.matmul(attn_weights_qdq, value_qdq)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        new_context_layer_shape = attn_output.size()[:-2] + (self.all_head_size,)
+        attn_output = attn_output.reshape(new_context_layer_shape)
+
+        return attn_output
+
+
+class SDPA(torch.nn.Module):
     def __init__(
         self,
         input_dim,
         has_mask,
-        num_attention_heads=None,
-        attention_head_size=None,
+        num_attention_heads,
+        attention_head_size,
     ) -> None:
         super().__init__()
         self.input_dim = input_dim
-        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
-        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
-        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
         self.softmax = torch.nn.Softmax(dim=-1)
-        assert num_attention_heads is not None
-        assert attention_head_size is not None
         self.num_attention_heads = num_attention_heads
         self.attention_head_size = attention_head_size
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
         self.dropout = torch.nn.Dropout(0)
         self.has_mask = has_mask
@@ -49,10 +207,7 @@ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         x = x.view(new_x_shape)
         return x.permute([0, 2, 1, 3])
 
-    def forward(self, x, mask):
-        q = self.q_proj(x)
-        k = self.k_proj(x)
-        v = self.v_proj(x)
+    def forward(self, q, k, v, mask):
         q = self.transpose_for_scores(q)
         k = self.transpose_for_scores(k)
         v = self.transpose_for_scores(v)
@@ -63,9 +218,38 @@ def forward(self, x, mask):
         attention = self.dropout(attention)
         context_layer = torch.matmul(attention, v)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        context_layer = context_layer.view(
-            context_layer.size()[:-2] + (self.all_head_size,)
+        return context_layer.reshape(context_layer.size()[:-2] + (self.all_head_size,))
+
+
+class MHAModule(torch.nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        has_mask,
+        num_attention_heads,
+        attention_head_size,
+    ) -> None:
+        super().__init__()
+        self.input_dim = input_dim
+        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_size = attention_head_size
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
+        self.attn_mod = SDPA(
+            input_dim,
+            has_mask,
+            num_attention_heads,
+            attention_head_size,
         )
+
+    def forward(self, x, mask):
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+        context_layer = self.attn_mod(q, k, v, mask)
         return self.dense(context_layer)
 
 
@@ -158,7 +342,7 @@ def _check_common(
         reason="cpp kernels not built",
     )
     @config.patch({"freezing": True})
-    def _test_qsdpa_rewriter(self):
+    def _test_int8_sdpa_rewriter(self):
         import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
         from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
         from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
@@ -171,7 +355,7 @@ def _test_qsdpa_rewriter(self):
             [torch.float32, torch.bfloat16], [True, False], [56, 1]
         ):
             seqlen, numhead, headsize = 197, 16, 64
-            mod = SelfAttnLikeModule(
+            mod = MHAModule(
                 input_dim=headsize * numhead,
                 has_mask=has_mask,
                 num_attention_heads=numhead,
@@ -204,6 +388,51 @@ def _test_qsdpa_rewriter(self):
                 prepare_model(*inputs)
                 convert_model = convert_pt2e(prepare_model)
                 torchao.quantization.pt2e.move_exported_model_to_eval(convert_model)
+
+                self._check_common(
+                    convert_model, args1=inputs, check_train=False, atol=1.0
+                )
+
+    @skipIfRocm
+    @unittest.skipIf(
+        not torch_version_at_least("2.7.0"),
+        reason="qsdpa requires torch 2.7 or later",
+    )
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
+        reason="cpp kernels not built",
+    )
+    @config.patch({"freezing": True})
+    def _test_fp8_sdpa_rewriter(self):
+        import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq  # noqa: F401
+
+        # pattern is different for bs=1
+        torch.manual_seed(1234)
+        for dtype, bs in itertools.product([torch.float32, torch.bfloat16], [56, 1]):
+            seqlen, numhead, headsize = 197, 16, 64
+            mod = MHAModule(
+                input_dim=headsize * numhead,
+                has_mask=False,
+                num_attention_heads=numhead,
+                attention_head_size=headsize,
+            ).eval()
+            inputs = (
+                torch.randn(
+                    (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype
+                ),
+                None,
+            )
+            enable_autocast = dtype == torch.bfloat16
+            with (
+                torch.no_grad(),
+                torch.amp.autocast(
+                    self.device, enabled=enable_autocast, dtype=torch.bfloat16
+                ),
+                config.patch(post_grad_custom_pre_pass=custom_pass),
+            ):
+                _qsdpa_init()
+                convert_model = fp8_convert_(mod)
+
                 self._check_common(
                     convert_model, args1=inputs, check_train=False, atol=1.0
                 )
@@ -213,7 +442,12 @@ def _test_qsdpa_rewriter(self):
 
 class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
     device = "cpu"
-    test_qsdpa_rewriter_cpu = TestSDPAPatternRewriterTemplate._test_qsdpa_rewriter
+    test_int8_sdpa_rewriter_cpu = (
+        TestSDPAPatternRewriterTemplate._test_int8_sdpa_rewriter
+    )
+    test_fp8_sdpa_rewriter_cpu = (
+        TestSDPAPatternRewriterTemplate._test_fp8_sdpa_rewriter
+    )
 
 
 if __name__ == "__main__":
diff --git a/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py b/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py
index 5e495a0623..ef0d94db62 100644
--- a/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py
+++ b/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py
@@ -28,7 +28,7 @@
 ]
 
 aten = torch.ops.aten
-quantize_dtypes = [torch.uint8]
+quantize_dtypes = [torch.uint8, torch.float8_e4m3fn]
 
 
 def _is_valid_qsdpa_pattern():
@@ -121,31 +121,53 @@ def qsdpa(match: Match, *args, **kwargs):
 def _generate_dequant_pattern(
     input_pattern, qtype, is_reduced_type, scale: str, zp: str = None
 ):
-    assert qtype is torch.uint8, "QSDPA expects type to be uint8"
-    assert zp is not None, "Zero point must be provided for uint8 dequantization"
-    return CallFunction(
-        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-        input_pattern,
-        KeywordArg(scale),
-        KeywordArg(zp),
-        Arg(),
-        Arg(),
-        Arg(),
-    )
+    if qtype == torch.uint8:
+        assert zp is not None, "Zero point must be provided for uint8 dequantization"
+        return CallFunction(
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            input_pattern,
+            KeywordArg(scale),
+            KeywordArg(zp),
+            Arg(),
+            Arg(),
+            Arg(),
+        )
+    else:
+        assert zp is None, "Fp8 dequantization does not support zero point"
+        if is_reduced_type:
+            return CallFunction(
+                torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+                input_pattern,
+                KeywordArg(scale),
+                Arg(),
+            )
+        else:
+            return CallFunction(
+                torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+                input_pattern,
+                KeywordArg(scale),
+            )
 
 
 def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str = None):
-    assert qtype is torch.uint8, "QSDPA expects type to be uint8"
-    assert zp is not None, "Zero point must be provided for uint8 quantization"
-    return CallFunction(
-        torch.ops.quantized_decomposed.quantize_per_tensor.default,
-        input_pattern,
-        KeywordArg(scale),
-        KeywordArg(zp),
-        Arg(),
-        Arg(),
-        Arg(),
-    )
+    if qtype == torch.uint8:
+        assert zp is not None, "Zero point must be provided for uint8 quantization"
+        return CallFunction(
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            input_pattern,
+            KeywordArg(scale),
+            KeywordArg(zp),
+            Arg(),
+            Arg(),
+            Arg(),
+        )
+    else:
+        assert zp is None, "Fp8 quantization does not support zero point"
+        return CallFunction(
+            torch.ops.torchao.quantize_affine_float8_non_decomposed.default,
+            input_pattern,
+            KeywordArg(scale),
+        )
 
 
 def _get_qsdpa_qkv_pattern(