Skip to content

Commit 9a0e918

Browse files
committed
wip
1 parent 7dff17a commit 9a0e918

File tree

8 files changed

+465
-46
lines changed

8 files changed

+465
-46
lines changed

test/sparsity/test_sparse_api.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from torchao.sparsity import (
99
apply_fake_sparsity,
10+
apply_fake_block_sparsity,
1011
sparsify_,
1112
semi_sparse_weight,
1213
)
@@ -96,5 +97,67 @@ def test_sparse_marlin(self):
9697

9798
assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
9899

100+
class TestBlockSparseWeight(TestCase):
101+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
102+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
103+
def test_sparse(self):
104+
input = torch.rand((1024, 1024)).half().cuda()
105+
model = (
106+
nn.Sequential(
107+
nn.Linear(1024, 2048),
108+
nn.Linear(2048, 1024),
109+
)
110+
.half()
111+
.cuda()
112+
)
113+
114+
from torchao.sparsity.utils import create_block_sparse_tensor
115+
M, N = model[0].weight.shape
116+
model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
117+
M, N = model[1].weight.shape
118+
model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
119+
dense_result = model(input)
120+
121+
from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight
122+
sparsify_(model, block_sparse_weight())
123+
sparse_result = model(input)
124+
125+
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
126+
127+
class TestQuantBlockSparseWeight(TestCase):
128+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
129+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
130+
def test_sparse(self):
131+
input = torch.rand((128, 128)).to(torch.bfloat16).cuda()
132+
model = (
133+
nn.Sequential(
134+
nn.Linear(128, 256),
135+
nn.Linear(256, 128),
136+
)
137+
.to(torch.bfloat16)
138+
.cuda()
139+
)
140+
141+
from torchao.sparsity.utils import create_block_sparse_tensor
142+
M, N = model[0].weight.shape
143+
model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16) * torch.rand(M, N, dtype=torch.bfloat16).cuda()
144+
print(model[0].weight)
145+
M, N = model[1].weight.shape
146+
model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
147+
print(model[1].weight)
148+
149+
model_copy = copy.deepcopy(model)
150+
151+
quantize_(model_copy, int8_dynamic_activation_int8_weight())
152+
reference = model_copy(input)
153+
154+
from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
155+
quantize_(model, int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType(), ))
156+
sparse_result = model(input)
157+
158+
print(reference)
159+
print(sparse_result)
160+
assert torch.allclose(reference, sparse_result, rtol=1e-2, atol=1e-2)
161+
99162
if __name__ == "__main__":
100163
unittest.main()

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 200 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
from torchao.float8.inference import Float8MMConfig
4848
aten = torch.ops.aten
4949

50-
5150
###############################
5251
# Base Layout Tensor Subclass #
5352
###############################
@@ -472,6 +471,13 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor:
472471
temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
473472
return temp
474473

474+
@dataclass(frozen=True)
475+
class BlockSparseLayoutType(LayoutType):
476+
477+
def pre_process(self, input: torch.Tensor) -> torch.Tensor:
478+
return input
479+
480+
475481

476482
@dataclass(frozen=True)
477483
class TensorCoreTiledLayoutType(LayoutType):
@@ -669,6 +675,162 @@ def from_plain(
669675
int_data_compressed = torch._cslt_compress(int_data)
670676
return cls(int_data_compressed, scale, zero_point, layout_type)
671677

678+
@register_layout_cls(BlockSparseLayoutType)
679+
class BlockSparseAQTLayout(PlainAQTLayout):
680+
quantized_linear_impl = "block"
681+
bsr_crow_indices: Optional[torch.Tensor]
682+
bsr_col_indices: Optional[torch.Tensor]
683+
bsr_values: Optional[torch.Tensor]
684+
scale: Optional[torch.Tensor]
685+
zero_point: Optional[torch.Tensor]
686+
687+
__slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"]
688+
689+
@staticmethod
690+
def __new__( # noqa: PYI034
691+
cls,
692+
shape: torch.Size,
693+
bsr_crow_indices: Optional[torch.Tensor],
694+
bsr_col_indices: Optional[torch.Tensor],
695+
bsr_values: Optional[torch.Tensor],
696+
scale: Optional[torch.Tensor],
697+
zero_point: Optional[torch.Tensor],
698+
layout_type: LayoutType,
699+
requires_grad: bool = False,
700+
):
701+
if bsr_values is None:
702+
raise ValueError("bsr values must be provided!")
703+
else:
704+
previous_tensor = bsr_values
705+
706+
kwargs = {
707+
"device": previous_tensor.device,
708+
"dtype": previous_tensor.dtype,
709+
"layout": previous_tensor.layout,
710+
"requires_grad": requires_grad,
711+
}
712+
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
713+
714+
def __init__( # noqa: PYI034
715+
self,
716+
shape: torch.Size,
717+
bsr_crow_indices: Optional[torch.Tensor],
718+
bsr_col_indices: Optional[torch.Tensor],
719+
bsr_values: Optional[torch.Tensor],
720+
scale: Optional[torch.Tensor],
721+
zero_point: Optional[torch.Tensor],
722+
layout_type: LayoutType,
723+
requires_grad: bool = False,
724+
):
725+
self.bsr_crow_indices = bsr_crow_indices
726+
self.bsr_col_indices = bsr_col_indices
727+
self.bsr_values = bsr_values
728+
self.scale = scale
729+
self.zero_point = zero_point
730+
self.layout_type = layout_type
731+
732+
def __repr__(self) -> str: # type: ignore[override]
733+
assert hasattr(self, "shape")
734+
return f"{self.__class__.__name__}(shape={self.shape})"
735+
736+
def __tensor_flatten__(self):
737+
inner_tensors = list(
738+
filter(lambda x: getattr(self, x) is not None, self.__slots__)
739+
)
740+
tensor_meta = (self.shape, self.layout_type, self.requires_grad)
741+
return inner_tensors, tensor_meta
742+
743+
@classmethod
744+
def __tensor_unflatten__(
745+
cls,
746+
inner_tensors,
747+
tensor_meta: Tuple[torch.Size, bool],
748+
outer_size,
749+
outer_stride,
750+
) -> torch.Tensor:
751+
shape, layout_type, requires_grad = tensor_meta
752+
return cls(
753+
shape=shape,
754+
bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
755+
bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
756+
bsr_values=inner_tensors.get("bsr_values", None),
757+
scale=inner_tensors.get("scale", None),
758+
zero_point=inner_tensors.get("zero_point", None),
759+
layout_type=layout_type,
760+
requires_grad=requires_grad,
761+
)
762+
763+
@classmethod
764+
def from_plain(cls, int_data, scale, zero_point, layout_type):
765+
bsr_tensor = int_data.to_sparse_bsr(64)
766+
return cls(
767+
shape=int_data.shape,
768+
bsr_crow_indices=bsr_tensor.crow_indices(),
769+
bsr_col_indices=bsr_tensor.col_indices(),
770+
bsr_values=bsr_tensor.values(),
771+
scale=scale,
772+
zero_point=zero_point,
773+
layout_type = layout_type,
774+
requires_grad=False,
775+
)
776+
777+
@torch._dynamo.disable
778+
def get_plain(self):
779+
# asdf = torch.eye(self.shape[1]).to(self.device)
780+
# self_bsr = torch.sparse_bsr_tensor(
781+
# self.crow_indices().to(self.device),
782+
# self.col_indices().to(self.device),
783+
# self.values().to(self.device),
784+
# size=(self.shape[0], self.shape[1])).to(self.dtype)
785+
# int_data_bsr = bsr_dense_mm(self_bsr, asdf)
786+
return torch.zeros(self.shape, device=self.device).to(self.dtype), self.scale, self.zero_point
787+
788+
def _apply_fn_to_data(self, func):
789+
return self.__class__(
790+
shape = self.shape,
791+
bsr_crow_indices=func(self.bsr_crow_indices),
792+
bsr_col_indices=func(self.bsr_col_indices),
793+
bsr_values=func(self.bsr_values),
794+
scale=self.scale,
795+
zero_point=self.zero_point,
796+
layout_type=self.layout_type,
797+
requires_grad=self.requires_grad,
798+
)
799+
800+
@classmethod
801+
def __torch_dispatch__(cls, func, types, args, kwargs):
802+
kwargs = {} if kwargs is None else kwargs
803+
804+
if func is aten.detach.default:
805+
return return_and_correct_aliasing(
806+
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
807+
)
808+
if func is aten.clone.default:
809+
return return_and_correct_aliasing(
810+
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
811+
)
812+
if func is aten.t.default:
813+
"""we don't need to repack the weight and just rely on external
814+
shape being changed and record the status of transpose/no-transpose
815+
"""
816+
args[0].transposed = not args[0].transposed
817+
return return_and_correct_aliasing(func, args, kwargs, args[0])
818+
819+
if func is aten.crow_indices.default:
820+
return args[0].bsr_crow_indices.detach()
821+
822+
if func is aten.col_indices.default:
823+
return args[0].bsr_col_indices.detach()
824+
825+
if func is aten.values.default:
826+
return args[0].bsr_values.detach()
827+
828+
if func is aten._nnz.default:
829+
return args[0].bsr_values.shape[0]
830+
831+
raise NotImplementedError(
832+
f"BlockSparseAQTLayout dispatch: attempting to run {func}, this is not supported"
833+
)
672834

673835
@register_layout_cls(MarlinSparseLayoutType)
674836
class MarlinSparseAQTLayout(AQTLayout):
@@ -1221,6 +1383,42 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh
12211383
y += bias
12221384
return y
12231385

1386+
def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias):
1387+
return (
1388+
isinstance(input_tensor, AffineQuantizedTensor) and
1389+
_aqt_is_int8_reduced_range(input_tensor) and
1390+
isinstance(weight_tensor, AffineQuantizedTensor) and
1391+
weight_tensor.is_cuda and
1392+
input_tensor.dtype == weight_tensor.dtype and
1393+
isinstance(input_tensor.layout_type, PlainLayoutType) and
1394+
isinstance(weight_tensor.layout_type, BlockSparseLayoutType) and
1395+
weight_tensor.layout_tensor.quantized_linear_impl == "block"
1396+
)
1397+
1398+
1399+
def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias):
1400+
x_vals_int8 = input_tensor.layout_tensor.int_data
1401+
x_scales = input_tensor.layout_tensor.scale
1402+
w_vals = weight_tensor.layout_tensor
1403+
w_scales = weight_tensor.layout_tensor.scale
1404+
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
1405+
tmp_t = tmp.t()
1406+
1407+
y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(),
1408+
w_vals.col_indices(),
1409+
w_vals.values(),
1410+
tmp_t,
1411+
w_scales,
1412+
x_scales.reshape(-1))
1413+
y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1])
1414+
y = y.reshape(*y_shape)
1415+
1416+
# can downcast only at the very end
1417+
output_dtype = input_tensor.dtype
1418+
y = y.to(output_dtype)
1419+
if bias is not None:
1420+
y += bias
1421+
return y
12241422
def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
12251423
return (
12261424
# input is native bfloat16 tensor
@@ -1473,6 +1671,7 @@ def _register_aqt_quantized_linear_dispatches():
14731671
for dispatch_condition, impl in [
14741672
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
14751673
(_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
1674+
(_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl),
14761675
(_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl),
14771676
(_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
14781677
(_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),

torchao/sparsity/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .utils import PerChannelNormObserver # noqa: F403
99
from .sparse_api import (
1010
apply_fake_sparsity,
11+
apply_fake_block_sparsity,
1112
sparsify_,
1213
semi_sparse_weight,
1314
int8_dynamic_activation_int8_semi_sparse_weight
@@ -17,6 +18,7 @@
1718
"WandaSparsifier",
1819
"PerChannelNormObserver",
1920
"apply_fake_sparsity",
21+
"apply_fake_block_sparsity",
2022
"sparsify_"
2123
"semi_sparse_weight",
2224
"int8_dynamic_activation_int8_semi_sparse_weight"

torchao/sparsity/prototype/superblock/benchmark.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import utils
1414
from torch import nn
1515
from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm
16+
from torch.sparse._triton_ops_meta import dump as store_tuned_kernel_params
1617
from torchao.sparsity.prototype.superblock.utils import accelerate_with_sparsity, simulate_sparsity
1718
from torchao.utils import benchmark_model, profiler_runner
1819

@@ -34,15 +35,30 @@ def main(args):
3435
# BSR kernel tuning
3536
if args.bsr and args.tune_kernel_params:
3637
print("Tuning kernel params")
38+
kwargs = dict(
39+
dtype=torch.int8 if args.quantization else dtype,
40+
sparsity=args.sparsity_linear, verbose=True,
41+
# per blocksparse_int_addmm:
42+
alpha=1, beta=0, use_left_alpha=True, use_right_alpha=True,
43+
# force tuning because existing tuning parameters are
44+
# computed for use_left/right_alpha=False, however, it
45+
# turns out that re-tuning for use_left/right_alpha=False
46+
# leads to the same set of tuning parametes:
47+
# force=True
48+
)
3749
if args.model == "vit_b_16":
38-
optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
39-
optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
50+
optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, **kwargs)
51+
optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, **kwargs)
4052
elif args.model == "vit_h_14":
41-
optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
42-
optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
53+
optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, **kwargs)
54+
optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, **kwargs)
4355
else:
4456
raise NotImplementedError("Tuning kernel params for this model is not supported yet.")
45-
57+
# Warning: the following call will overwrite the source code
58+
# of torch.sparse._triton_ops_meta (hence it is commented out
59+
# by default) but when used, it'll enables reusing the tuned
60+
# parameters in subsequent runs of this script:
61+
# store_tuned_kernel_params()
4662
print("Creating model")
4763
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
4864

0 commit comments

Comments
 (0)