Commit b6fc991

initial commit

1 parent f5703b0 commit b6fc991

File tree

12 files changed (+407, -21 lines)

test/sparsity/test_sparse_api.py

Lines changed: 87 additions & 0 deletions
@@ -21,11 +21,33 @@
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
 from torch.testing._internal.common_utils import TestCase
 
+from torch.ao.pruning import WeightNormSparsifier
+
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
 
+
+def apply_fake_block_sparsity(model, **kwargs):
+    """
+    This function simulates block sparsity (64x64 blocks) on all linear
+    layers in a model. It uses the torch.ao.pruning flow.
+    """
+    filter_fn = kwargs.pop("filter_fn", _is_linear)
+    # torch.ao.pruning flow
+    sparse_config = []
+    for name, mod in model.named_modules():
+        if filter_fn(mod, name):
+            sparse_config.append({"tensor_fqn": f"{name}.weight"})
+
+    sparsifier = WeightNormSparsifier(
+        sparsity_level=0.5, sparse_block_shape=(64, 64)
+    )
+    sparsifier.prepare(model, sparse_config)
+    sparsifier.step()
+    sparsifier.squash_mask()
+
+
 class TestSemiStructuredSparse(TestCase):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
@@ -73,5 +95,70 @@ def test_quant_semi_sparse(self):
 
         assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
 
+
+class TestBlockSparseWeight(TestCase):
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_sparse(self):
+        input = torch.rand((1024, 1024)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(1024, 2048),
+                nn.Linear(2048, 1024),
+            )
+            .half()
+            .cuda()
+        )
+
+        from torchao.sparsity.utils import create_block_sparse_tensor
+        M, N = model[0].weight.shape
+        model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
+        dense_result = model(input)
+
+        from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight
+        sparsify_(model, block_sparse_weight())
+        sparse_result = model(input)
+
+        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+
+class TestQuantBlockSparseWeight(TestCase):
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_sparse(self):
+        input = torch.rand((128, 128)).to(torch.bfloat16).cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(128, 256),
+                nn.Linear(256, 128),
+            )
+            .to(torch.bfloat16)
+            .cuda()
+        )
+
+        from torchao.sparsity.utils import create_block_sparse_tensor
+        M, N = model[0].weight.shape
+        model[0].weight.data = (
+            create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
+            * torch.rand(M, N, dtype=torch.bfloat16).cuda()
+        )
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
+
+        model_copy = copy.deepcopy(model)
+
+        quantize_(model_copy, int8_dynamic_activation_int8_weight())
+        reference = model_copy(input)
+
+        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
+        quantize_(model, int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType()))
+        sparse_result = model(input)
+
+        assert torch.allclose(reference, sparse_result, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
     unittest.main()
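
For orientation, a minimal sketch of how the pieces exercised above compose end-to-end, assuming the apply_fake_block_sparsity helper defined in this diff is in scope and that sparsify_ is importable from torchao.sparsity as elsewhere in this test suite (the toy model and shapes are illustrative):

import torch
import torch.nn as nn

from torchao.sparsity import sparsify_
from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight

# Illustrative toy model; any stack of nn.Linear layers works the same way.
model = nn.Sequential(nn.Linear(1024, 2048), nn.Linear(2048, 1024)).half().cuda()

# First zero out 50% of the 64x64 weight blocks with the helper above,
# then swap the dense weights for the accelerated block-sparse representation.
apply_fake_block_sparsity(model)
sparsify_(model, block_sparse_weight())

output = model(torch.rand(1024, 1024).half().cuda())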

torchao/_models/sam/benchmark.sh

Lines changed: 10 additions & 9 deletions
@@ -1,10 +1,11 @@
 # baseline
-python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
-# int8 dynamic quant (all)
-python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
-# 2:4 sparsity (all)
-python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
-# 2:4 sparsity (mlp only)
-python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
-# int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
-python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
+#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
+## int8 dynamic quant (all)
+#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
+## 2:4 sparsity (all)
+#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
+## 2:4 sparsity (mlp only)
+#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
+## int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
+#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_block_sparse

torchao/_models/sam/eval_combo.py

Lines changed: 5 additions & 1 deletion
@@ -320,7 +320,11 @@ def mlp_only(mod, name):
                       mlp_lin2_only)
         if not TORCH_VERSION_AT_LEAST_2_5:
             predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
-
+    elif compress == "int8_dynamic_quant_block_sparse":
+        def mlp_only(mod, name):
+            return isinstance(mod, torch.nn.Linear) and 'mlp' in name
+        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
+        quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType()), mlp_only)
     else:
         assert compress is None, f"Unsupported compress mode {compress}"
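
Outside of SAM, the new compress mode reduces to a short recipe. A minimal sketch, assuming the quantize_ and int8_dynamic_activation_int8_weight APIs used in this diff (the toy model and the filter function are illustrative):

import torch
import torch.nn as nn

from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType

model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).to(torch.bfloat16).cuda()

# quantize_ takes an optional filter over (module, fully-qualified name);
# the SAM path above restricts it to MLP Linears, while this toy filter
# matches every Linear in the model.
filter_fn = lambda mod, name: isinstance(mod, nn.Linear)

# int8 dynamic activation / int8 weight quantization, with the quantized
# weights stored in the new block-sparse layout.
quantize_(model, int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType()), filter_fn)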

torchao/_models/sam/results.csv

Lines changed: 6 additions & 0 deletions
@@ -4,3 +4,9 @@ cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,ma
 cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
+cuda,vit_h,32,15172,18,22.787559123509425,43.88359431477336,0.5809962729163862,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15153,18,24.872293344547476,40.20537978333312,0.5821541984818872,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15640,19,24.64409232721636,40.5776762528853,0.5674436009126148,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,13429,16,24.710537332827382,40.46856555691013,0.530554119734646,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,26.5429434697436,37.67479673608557,0.566992236284673,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
