23 changes: 18 additions & 5 deletions benchmarks/microbenchmarks/benchmark_inference.py
@@ -50,13 +50,26 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
config.sparsity,
high_precision_dtype=config.high_precision_dtype,
)
if config.sparsity != "None" and config.quantization == "baseline":
print(f"Sparsifying model for sparsity: {config.sparsity}")
sparsify_(m_copy, aoBaseConfig)
elif config.sparsity == "None" and config.quantization == "baseline":

# Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
is_cuda = config.device == "cuda" and torch.cuda.is_available()

if config.sparsity is not None and (
config.quantization is None or "baseline" in config.quantization
):
if is_cuda:
print(f"Applying {config.sparsity} sparsity to model")
sparsify_(m_copy, aoBaseConfig)
else:
print(
f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
)
elif config.sparsity is None and (
config.quantization is None or "baseline" in config.quantization
):
pass # No quantization or sparsity specified, do nothing
else:
print(f"Quantizing model with quantization: {config.quantization}, sparsity: {config.sparsity}")
print("Quantizing model....")
quantize_(m_copy, aoBaseConfig)

if config.use_torch_compile:
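# A minimal standalone sketch (not part of the PR) of the routing added above: sparsity-only
# requests go through sparsify_ and require CUDA, plain baseline runs leave the model untouched,
# and anything with a non-baseline quantization recipe goes through quantize_. The helper name
# `apply_transform` and its string return value are illustrative assumptions, not the PR's API.
import torch
from torchao.quantization import quantize_
from torchao.sparsity import sparsify_


def apply_transform(model, ao_base_config, quantization, sparsity, device):
    is_cuda = device == "cuda" and torch.cuda.is_available()
    is_baseline = quantization is None or "baseline" in quantization

    if sparsity is not None and is_baseline:
        if not is_cuda:
            # Sparsity kernels require CUDA, so skip rather than fail.
            return "skipped"
        sparsify_(model, ao_base_config)
        return "sparsified"
    if sparsity is None and is_baseline:
        return "unchanged"  # nothing requested
    quantize_(model, ao_base_config)
    return "quantized"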
84 changes: 58 additions & 26 deletions benchmarks/microbenchmarks/benchmark_runner.py
@@ -21,7 +21,7 @@

import argparse
from itertools import product
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple

import yaml

@@ -69,28 +69,57 @@ def get_param_combinations(model_param):


def get_quantization_sparsity_recipes(
quantization_recipes: str, sparsity_recipes: str
) -> List[Tuple[str, str]]:
"""Generate valid quantization and sparsity recipes."""

config_recipes = []
for quant_config, sparse_config in product(quantization_recipes, sparsity_recipes):
if sparse_config != "None":
if "semi" in sparse_config or "2:4" in sparse_config:
quantization_recipes: List[str], sparsity_recipes: List[str]
) -> Set[Tuple[str, Optional[str]]]:
"""Generate valid quantization and sparsity recipes.

Args:
quantization_recipes: List of quantization recipes
sparsity_recipes: List of sparsity recipes

Returns:
Set of tuples containing (quantization_recipe, sparsity_recipe)
For block sparsity, quantization is always "baseline"
All quantization techniques are also run without sparsity
"""
config_recipes = set()

# Handle edge cases
if sparsity_recipes is None and quantization_recipes is None:
return {("baseline", None)}
if sparsity_recipes is None:
return {(quant, None) for quant in quantization_recipes}
if quantization_recipes is None:
return {("baseline", sparse) for sparse in sparsity_recipes}

# Always include baseline without sparsity
config_recipes.add(("baseline", None))

# Add all quantization techniques without sparsity
for quant_config in quantization_recipes:
config_recipes.add((quant_config, None))

# Process combinations of quantization and sparsity
for sparse_config in sparsity_recipes:
if sparse_config is None:
# Skip None sparsity as we've already added all quantization techniques without sparsity
continue
elif "block" in sparse_config:
# For block sparsity, only pair with baseline quantization
config_recipes.add(("baseline", sparse_config))
elif "semi" in sparse_config or "2:4" in sparse_config:
# For semi-sparse, only pair with compatible quantization methods
for quant_config in quantization_recipes:
if (
"marlin" in quant_config
or "int8dq" in quant_config
or "float8dq" in quant_config
or quant_config == "baseline"
):
pass
else:
continue
elif sparse_config == "block":
config_recipes.append(("baseline", sparse_config))
else:
raise ValueError(f"Invalid sparsity recipe: {sparse_config}")
config_recipes.append((quant_config, sparse_config))
config_recipes.add((quant_config, sparse_config))
else:
raise ValueError(f"Invalid sparsity recipe: {sparse_config}")

return config_recipes


@@ -104,15 +133,16 @@ def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig

# Create all possible combinations
configs = []
quantization_sparsity_recipes = get_quantization_sparsity_recipes(
config.get("quantization_config_recipe_names", None),
config.get("sparsity_config_recipe_names", None),
)
for model_param in config["model_params"]:
shapes, params = get_param_combinations(model_param)

# Create configs for all combinations
for (quant_config, sparse_config), (shape_name, shape) in product(
get_quantization_sparsity_recipes(
config.get("quantization_config_recipe_names", ["baseline"]),
config.get("sparsity_config_recipe_names", ["None"]),
),
quantization_sparsity_recipes,
shapes,
):
configs.append(
@@ -126,7 +156,6 @@
benchmark_mode=benchmark_mode,
)
)

return configs


@@ -135,14 +164,17 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None
from benchmarks.microbenchmarks.benchmark_inference import run as run_inference

results = []
print("Benchmarking Inference ......")
print("----------------- RUNNING BENCHMARKS FOR INFERENCE -----------------------")
for config in configs:
print("----------------------------------------")
try:
print(f"Running: {config.name}")
print(
f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}"
)
result = run_inference(config) # Pass the config object directly
results.append(result)
except Exception as e:
print(f"Error running benchmark {config.name}: {e}")
except Exception:
print(f"Error running benchmark {config.name}")
continue

# Add results to csv
14 changes: 0 additions & 14 deletions benchmarks/microbenchmarks/results/results.csv

This file was deleted.

4 changes: 2 additions & 2 deletions benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -1,11 +1,11 @@
# Sample configuration for inference benchmarks
benchmark_mode: "inference"
quantization_config_recipe_names:
- "baseline"
# - "baseline" Will always run a baseline instatance
- "int4wo-32"
- "marlin"
sparsity_config_recipe_names:
- "None"
# - "none" Will always run a without sparsity instance
- "semi-sparse"
- "block"
output_dir: "benchmarks/microbenchmarks/results"
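# A small sketch (not part of the PR) of how the recipe names in this YAML reach the runner:
# the file is read with yaml.safe_load and the two lists are expanded into
# (quantization, sparsity) pairs by get_quantization_sparsity_recipes.
import yaml

from benchmarks.microbenchmarks.benchmark_runner import get_quantization_sparsity_recipes

with open("benchmarks/microbenchmarks/test/benchmark_config.yml") as f:
    config = yaml.safe_load(f)

recipes = get_quantization_sparsity_recipes(
    config.get("quantization_config_recipe_names", None),
    config.get("sparsity_config_recipe_names", None),
)
# With the names above this yields ("baseline", None), ("int4wo-32", None), ("marlin", None),
# ("marlin", "semi-sparse"), and ("baseline", "block").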
24 changes: 20 additions & 4 deletions benchmarks/microbenchmarks/test/test_benchmark_inference.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import tempfile
import unittest
from unittest.mock import patch

from benchmarks.microbenchmarks.benchmark_inference import run
from benchmarks.microbenchmarks.utils import BenchmarkConfig, BenchmarkResult
@@ -36,14 +37,28 @@ def tearDown(self):

shutil.rmtree(self.temp_dir)

def test_run_inference(self):
@patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
def test_run_inference(self, mock_string_to_config):
# Mock string_to_config to return a valid config
from torchao.sparsity.sparse_api import SemiSparseWeightConfig

mock_string_to_config.return_value = SemiSparseWeightConfig()

result = run(self.config)
self.assertIsInstance(result, BenchmarkResult)
self.assertTrue(hasattr(result, "model_inference_time_in_ms"))

def test_run_inference_with_sparsity(self):
@patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
def test_run_inference_with_sparsity(self, mock_string_to_config):
"""Test running inference with sparsity configurations"""
# Mock string_to_config to return valid configs
from torchao.quantization import Int4WeightOnlyConfig
from torchao.sparsity.sparse_api import (
BlockSparseWeightConfig,
)

# Test with semi-sparse config
mock_string_to_config.return_value = Int4WeightOnlyConfig()
config = BenchmarkConfig(
quantization="marlin",
sparsity="semi-sparse",
@@ -54,7 +69,7 @@ def test_run_inference_with_sparsity(self):
"model_type": "linear",
},
shape_name="custom",
shape=[16, 32, 8],
shape=[64, 64, 64], # Use dimensions divisible by 64
output_dir=self.temp_dir,
benchmark_mode="inference",
)
@@ -63,6 +78,7 @@
self.assertTrue(hasattr(result, "model_inference_time_in_ms"))

# Test with block sparsity
mock_string_to_config.return_value = BlockSparseWeightConfig()
config = BenchmarkConfig(
quantization="baseline",
sparsity="block",
@@ -73,7 +89,7 @@
"model_type": "linear",
},
shape_name="custom",
shape=[16, 32, 8],
shape=[64, 64, 64], # Use dimensions divisible by 64
output_dir=self.temp_dir,
benchmark_mode="inference",
)
65 changes: 53 additions & 12 deletions benchmarks/microbenchmarks/test/test_benchmark_runner.py
@@ -93,34 +93,75 @@ def test_get_quantization_sparsity_recipes(self):
"""Test generation of valid quantization and sparsity recipe combinations"""
# Test basic combinations
quant_recipes = ["baseline", "int8wo"]
sparse_recipes = ["None", "semi-sparse"]
sparse_recipes = [None, "semi-sparse"]
recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
self.assertEqual(
len(recipes), 3
) # Should get baseline+None, int8wo+None, and baseline+semi-sparse
self.assertIn(("baseline", "None"), recipes)
self.assertIn(("int8wo", "None"), recipes)
self.assertIn(("baseline", None), recipes)
self.assertIn(("int8wo", None), recipes)
self.assertIn(("baseline", "semi-sparse"), recipes)

# Test marlin with semi-sparse
quant_recipes = ["marlin", "baseline"]
sparse_recipes = ["None", "semi-sparse"]
sparse_recipes = [None, "semi-sparse"]
recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
self.assertIn(("marlin", "semi-sparse"), recipes)
self.assertIn(("baseline", "None"), recipes)
self.assertIn(("baseline", None), recipes)

# Test block sparsity
quant_recipes = ["baseline"]
sparse_recipes = ["None", "block"]
sparse_recipes = [None, "block"]
recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
self.assertIn(("baseline", "block"), recipes)

def test_load_benchmark_configs_with_sparsity(self):
def test_none_string_raises_error(self):
"""Test that passing 'None' as a string raises an error"""
quant_recipes = ["baseline"]
sparse_recipes = ["None"] # "None" as a string should raise an error
with self.assertRaises(ValueError):
get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)

def test_block_sparsity_with_quantization(self):
"""Test that block sparsity is only paired with baseline quantization"""
quant_recipes = ["baseline", "int8wo", "int4wo", "marlin"]
sparse_recipes = ["block"]
recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)

# Block sparsity should only be paired with baseline
self.assertIn(("baseline", "block"), recipes)
self.assertNotIn(("int8wo", "block"), recipes)
self.assertNotIn(("int4wo", "block"), recipes)
self.assertNotIn(("marlin", "block"), recipes)

# All quantization techniques should be run without sparsity
self.assertIn(("baseline", None), recipes)
self.assertIn(("int8wo", None), recipes)
self.assertIn(("int4wo", None), recipes)
self.assertIn(("marlin", None), recipes)

def test_all_quantization_without_sparsity(self):
"""Test that all quantization techniques are run without sparsity"""
quant_recipes = ["baseline", "int8wo", "int4wo", "marlin"]
sparse_recipes = [None, "semi-sparse", "block"]
recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)

# All quantization techniques should be run without sparsity
for quant in quant_recipes:
self.assertIn((quant, None), recipes)

@patch(
"benchmarks.microbenchmarks.benchmark_runner.get_quantization_sparsity_recipes"
)
def test_load_benchmark_configs_with_sparsity(self, mock_get_recipes):
"""Test loading benchmark configs with sparsity options"""
# Mock get_quantization_sparsity_recipes to return a valid set of recipes
mock_get_recipes.return_value = {("baseline", None), ("marlin", "semi-sparse")}

test_config = {
"benchmark_mode": "inference",
"quantization_config_recipe_names": ["baseline", "marlin"],
"sparsity_config_recipe_names": ["None", "semi-sparse"],
"sparsity_config_recipe_names": [
None,
"semi-sparse",
], # Use None instead of "None"
"output_dir": self.temp_dir,
"model_params": [
{
@@ -142,7 +183,7 @@

# Check that we get configs for baseline and marlin with appropriate sparsity
self.assertTrue(
any(c.quantization == "baseline" and c.sparsity == "None" for c in configs)
any(c.quantization == "baseline" and c.sparsity is None for c in configs)
)
self.assertTrue(
any(
17 changes: 17 additions & 0 deletions benchmarks/microbenchmarks/test/test_utils.py
@@ -119,6 +119,23 @@ def test_string_to_config_sparsity(self):
config, Float8DynamicActivationFloat8SemiSparseWeightConfig
)

def test_block_sparsity_with_baseline_quantization(self):
"""Test that block sparsity with baseline quantization returns BlockSparseWeightConfig"""
config = string_to_config("baseline", "block")
self.assertIsInstance(config, BlockSparseWeightConfig)

def test_block_sparsity_with_non_baseline_quantization(self):
"""Test that block sparsity with non-baseline quantization still returns BlockSparseWeightConfig"""
# Block sparsity should take precedence over any quantization method
config = string_to_config("int8wo", "block")
self.assertIsInstance(config, BlockSparseWeightConfig)

config = string_to_config("int4wo", "block")
self.assertIsInstance(config, BlockSparseWeightConfig)

config = string_to_config("marlin", "block")
self.assertIsInstance(config, BlockSparseWeightConfig)

def test_invalid_sparsity(self):
"""Test invalid sparsity config generation"""
from benchmarks.microbenchmarks.benchmark_runner import (