|
47 | 47 | from torchao.float8.inference import Float8MMConfig |
48 | 48 | aten = torch.ops.aten |
49 | 49 |
|
50 | | - |
51 | 50 | ############################### |
52 | 51 | # Base Layout Tensor Subclass # |
53 | 52 | ############################### |
@@ -472,6 +471,13 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: |
472 | 471 | temp.view(-1, 4).scatter_(1, pruning_inds, value=0) |
473 | 472 | return temp |
474 | 473 |
|
| 474 | +@dataclass(frozen=True) |
| 475 | +class BlockSparseLayoutType(LayoutType): |
| 476 | + |
| 477 | + def pre_process(self, input: torch.Tensor) -> torch.Tensor: |
| 478 | + return input |
| 479 | + |
| 480 | + |
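
For contrast with the 2:4 pre_process in the surrounding context (which scatters zeros into every group of four values), this layout's pre_process is a no-op: the weight is assumed to already carry its block-sparse structure, e.g. from earlier pruning. A minimal sketch of the expected behavior; the shapes are illustrative and nothing below is part of the diff:

    import torch

    layout_type = BlockSparseLayoutType()
    w = torch.randn(128, 128)
    # pre_process hands the tensor back unchanged; block sparsification happens upstream
    assert layout_type.pre_process(w) is w
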
475 | 481 |
|
476 | 482 | @dataclass(frozen=True) |
477 | 483 | class TensorCoreTiledLayoutType(LayoutType): |
@@ -669,6 +675,162 @@ def from_plain( |
669 | 675 | int_data_compressed = torch._cslt_compress(int_data) |
670 | 676 | return cls(int_data_compressed, scale, zero_point, layout_type) |
671 | 677 |
|
| 678 | +@register_layout_cls(BlockSparseLayoutType) |
| 679 | +class BlockSparseAQTLayout(PlainAQTLayout): |
| 680 | + quantized_linear_impl = "block" |
| 681 | + bsr_crow_indices: Optional[torch.Tensor] |
| 682 | + bsr_col_indices: Optional[torch.Tensor] |
| 683 | + bsr_values: Optional[torch.Tensor] |
| 684 | + scale: Optional[torch.Tensor] |
| 685 | + zero_point: Optional[torch.Tensor] |
| 686 | + |
| 687 | + __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"] |
| 688 | + |
| 689 | + @staticmethod |
| 690 | + def __new__( # noqa: PYI034 |
| 691 | + cls, |
| 692 | + shape: torch.Size, |
| 693 | + bsr_crow_indices: Optional[torch.Tensor], |
| 694 | + bsr_col_indices: Optional[torch.Tensor], |
| 695 | + bsr_values: Optional[torch.Tensor], |
| 696 | + scale: Optional[torch.Tensor], |
| 697 | + zero_point: Optional[torch.Tensor], |
| 698 | + layout_type: LayoutType, |
| 699 | + requires_grad: bool = False, |
| 700 | + ): |
| 701 | + if bsr_values is None: |
| 702 | + raise ValueError("bsr values must be provided!") |
| 703 | + else: |
| 704 | + previous_tensor = bsr_values |
| 705 | + |
| 706 | + kwargs = { |
| 707 | + "device": previous_tensor.device, |
| 708 | + "dtype": previous_tensor.dtype, |
| 709 | + "layout": previous_tensor.layout, |
| 710 | + "requires_grad": requires_grad, |
| 711 | + } |
| 712 | + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] |
| 713 | + |
| 714 | + def __init__( # noqa: PYI034 |
| 715 | + self, |
| 716 | + shape: torch.Size, |
| 717 | + bsr_crow_indices: Optional[torch.Tensor], |
| 718 | + bsr_col_indices: Optional[torch.Tensor], |
| 719 | + bsr_values: Optional[torch.Tensor], |
| 720 | + scale: Optional[torch.Tensor], |
| 721 | + zero_point: Optional[torch.Tensor], |
| 722 | + layout_type: LayoutType, |
| 723 | + requires_grad: bool = False, |
| 724 | + ): |
| 725 | + self.bsr_crow_indices = bsr_crow_indices |
| 726 | + self.bsr_col_indices = bsr_col_indices |
| 727 | + self.bsr_values = bsr_values |
| 728 | + self.scale = scale |
| 729 | + self.zero_point = zero_point |
| 730 | + self.layout_type = layout_type |
| 731 | + |
| 732 | + def __repr__(self) -> str: # type: ignore[override] |
| 733 | + assert hasattr(self, "shape") |
| 734 | + return f"{self.__class__.__name__}(shape={self.shape})" |
| 735 | + |
| 736 | + def __tensor_flatten__(self): |
| 737 | + inner_tensors = list( |
| 738 | + filter(lambda x: getattr(self, x) is not None, self.__slots__) |
| 739 | + ) |
| 740 | + tensor_meta = (self.shape, self.layout_type, self.requires_grad) |
| 741 | + return inner_tensors, tensor_meta |
| 742 | + |
| 743 | + @classmethod |
| 744 | + def __tensor_unflatten__( |
| 745 | + cls, |
| 746 | + inner_tensors, |
| 747 | + tensor_meta: Tuple[torch.Size, LayoutType, bool], |
| 748 | + outer_size, |
| 749 | + outer_stride, |
| 750 | + ) -> torch.Tensor: |
| 751 | + shape, layout_type, requires_grad = tensor_meta |
| 752 | + return cls( |
| 753 | + shape=shape, |
| 754 | + bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), |
| 755 | + bsr_col_indices=inner_tensors.get("bsr_col_indices", None), |
| 756 | + bsr_values=inner_tensors.get("bsr_values", None), |
| 757 | + scale=inner_tensors.get("scale", None), |
| 758 | + zero_point=inner_tensors.get("zero_point", None), |
| 759 | + layout_type=layout_type, |
| 760 | + requires_grad=requires_grad, |
| 761 | + ) |
| 762 | + |
| 763 | + @classmethod |
| 764 | + def from_plain(cls, int_data, scale, zero_point, layout_type): |
| 765 | + bsr_tensor = int_data.to_sparse_bsr(64) |
| 766 | + return cls( |
| 767 | + shape=int_data.shape, |
| 768 | + bsr_crow_indices=bsr_tensor.crow_indices(), |
| 769 | + bsr_col_indices=bsr_tensor.col_indices(), |
| 770 | + bsr_values=bsr_tensor.values(), |
| 771 | + scale=scale, |
| 772 | + zero_point=zero_point, |
| 773 | + layout_type=layout_type, |
| 774 | + requires_grad=False, |
| 775 | + ) |
| 776 | + |
| 777 | + @torch._dynamo.disable |
| 778 | + def get_plain(self): |
| 779 | + # Materializing the dense tensor from the BSR components |
| 780 | + # (crow_indices / col_indices / values) is not implemented here; |
| 781 | + # one option would be to multiply the BSR weight by an identity |
| 782 | + # matrix to densify it. |
| 783 | + # Instead, return a zero placeholder of the correct shape and dtype |
| 784 | + # together with the stored scale and zero_point. Note that this |
| 785 | + # means get_plain() does not round-trip the packed values. |
| 786 | + return torch.zeros(self.shape, device=self.device).to(self.dtype), self.scale, self.zero_point |
| 787 | + |
| 788 | + def _apply_fn_to_data(self, func): |
| 789 | + return self.__class__( |
| 790 | + shape=self.shape, |
| 791 | + bsr_crow_indices=func(self.bsr_crow_indices), |
| 792 | + bsr_col_indices=func(self.bsr_col_indices), |
| 793 | + bsr_values=func(self.bsr_values), |
| 794 | + scale=self.scale, |
| 795 | + zero_point=self.zero_point, |
| 796 | + layout_type=self.layout_type, |
| 797 | + requires_grad=self.requires_grad, |
| 798 | + ) |
| 799 | + |
| 800 | + @classmethod |
| 801 | + def __torch_dispatch__(cls, func, types, args, kwargs): |
| 802 | + kwargs = {} if kwargs is None else kwargs |
| 803 | + |
| 804 | + if func is aten.detach.default: |
| 805 | + return return_and_correct_aliasing( |
| 806 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) |
| 807 | + ) |
| 808 | + if func is aten.clone.default: |
| 809 | + return return_and_correct_aliasing( |
| 810 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) |
| 811 | + ) |
| 812 | + if func is aten.t.default: |
| 813 | + """We don't need to repack the weight here; we just rely on the |
| 814 | + external shape being changed and record the transpose status. |
| 815 | + """ |
| 816 | + args[0].transposed = not args[0].transposed |
| 817 | + return return_and_correct_aliasing(func, args, kwargs, args[0]) |
| 818 | + |
| 819 | + if func is aten.crow_indices.default: |
| 820 | + return args[0].bsr_crow_indices.detach() |
| 821 | + |
| 822 | + if func is aten.col_indices.default: |
| 823 | + return args[0].bsr_col_indices.detach() |
| 824 | + |
| 825 | + if func is aten.values.default: |
| 826 | + return args[0].bsr_values.detach() |
| 827 | + |
| 828 | + if func is aten._nnz.default: |
| 829 | + return args[0].bsr_values.shape[0] |
| 830 | + |
| 831 | + raise NotImplementedError( |
| 832 | + f"BlockSparseAQTLayout dispatch: attempting to run {func}, this is not supported" |
| 833 | + ) |
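
To make the storage scheme concrete, here is a small sketch of what from_plain does with a dense int8 weight and what the flatten/unflatten pair exposes; the shapes are illustrative and chosen to be divisible by the 64x64 block size hard-coded above:

    import torch

    int_data = torch.randint(-128, 128, (256, 256), dtype=torch.int8)
    bsr = int_data.to_sparse_bsr(64)  # same conversion as in from_plain()
    crow, col, vals = bsr.crow_indices(), bsr.col_indices(), bsr.values()
    # __tensor_flatten__ reports these components (plus scale/zero_point when set)
    # as the inner tensors, and __tensor_unflatten__ rebuilds the subclass from
    # them, which is what torch.compile and serialization rely on.

The aten.crow_indices / aten.col_indices / aten.values / aten._nnz branches in __torch_dispatch__ are what let downstream code call those accessors directly on the layout tensor, which the block-sparse linear impl below does.
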
672 | 834 |
|
673 | 835 | @register_layout_cls(MarlinSparseLayoutType) |
674 | 836 | class MarlinSparseAQTLayout(AQTLayout): |
@@ -1221,6 +1383,42 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh |
1221 | 1383 | y += bias |
1222 | 1384 | return y |
1223 | 1385 |
|
| 1386 | +def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): |
| 1387 | + return ( |
| 1388 | + isinstance(input_tensor, AffineQuantizedTensor) and |
| 1389 | + _aqt_is_int8_reduced_range(input_tensor) and |
| 1390 | + isinstance(weight_tensor, AffineQuantizedTensor) and |
| 1391 | + weight_tensor.is_cuda and |
| 1392 | + input_tensor.dtype == weight_tensor.dtype and |
| 1393 | + isinstance(input_tensor.layout_type, PlainLayoutType) and |
| 1394 | + isinstance(weight_tensor.layout_type, BlockSparseLayoutType) and |
| 1395 | + weight_tensor.layout_tensor.quantized_linear_impl == "block" |
| 1396 | + ) |
| 1397 | + |
| 1398 | + |
| 1399 | +def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): |
| 1400 | + x_vals_int8 = input_tensor.layout_tensor.int_data |
| 1401 | + x_scales = input_tensor.layout_tensor.scale |
| 1402 | + w_vals = weight_tensor.layout_tensor |
| 1403 | + w_scales = weight_tensor.layout_tensor.scale |
| 1404 | + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) |
| 1405 | + tmp_t = tmp.t() |
| 1406 | + |
| 1407 | + y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(), |
| 1408 | + w_vals.col_indices(), |
| 1409 | + w_vals.values(), |
| 1410 | + tmp_t, |
| 1411 | + w_scales, |
| 1412 | + x_scales.reshape(-1)) |
| 1413 | + y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) |
| 1414 | + y = y.reshape(*y_shape) |
| 1415 | + |
| 1416 | + # can downcast only at the very end |
| 1417 | + output_dtype = input_tensor.dtype |
| 1418 | + y = y.to(output_dtype) |
| 1419 | + if bias is not None: |
| 1420 | + y += bias |
| 1421 | + return y |
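
The fused torch.ops.blocksparse.int_addmm kernel is defined elsewhere in this PR, so its exact signature is not shown here; as a hedged dense reference, the math the dispatch path above is expected to realize for per-row-quantized int8 activations against a per-channel-quantized int8 weight is roughly:

    import torch

    def reference_int8_linear(x_int8, x_scale, w_int8, w_scale, bias=None):
        # x_int8: (M, K) int8 activations, x_scale: (M,) per-row scales
        # w_int8: (N, K) int8 weight (dense stand-in for the BSR weight), w_scale: (N,) per-channel scales
        # float64 keeps the integer accumulation exact for this reference
        acc = x_int8.to(torch.float64) @ w_int8.to(torch.float64).t()
        y = acc * x_scale.reshape(-1, 1).to(torch.float64) * w_scale.reshape(1, -1).to(torch.float64)
        if bias is not None:
            y = y + bias
        return y

The impl above additionally flattens any leading batch dimensions of the activation before the matmul, hands the activation to the kernel transposed, reshapes the result back to (*batch, out_features), and only downcasts to the activation dtype at the very end.
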
1224 | 1422 | def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): |
1225 | 1423 | return ( |
1226 | 1424 | # input is native bfloat16 tensor |
@@ -1473,6 +1671,7 @@ def _register_aqt_quantized_linear_dispatches(): |
1473 | 1671 | for dispatch_condition, impl in [ |
1474 | 1672 | (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), |
1475 | 1673 | (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), |
| 1674 | + (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl), |
1476 | 1675 | (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), |
1477 | 1676 | (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), |
1478 | 1677 | (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), |
|