8 changes: 4 additions & 4 deletions benchmarks/microbenchmarks/benchmark_inference.py
@@ -23,8 +23,6 @@
model_inference_time_in_ms,
string_to_config,
)
from torchao import quantization
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.sparsity.sparse_api import sparsify_

@@ -48,14 +46,16 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
# Use quantize_ to apply each quantization function to the model
m_copy = deepcopy(base_model).eval().to(config.device)
aoBaseConfig = string_to_config(
config.quantization, config.sparsity, high_precision_dtype=config.high_precision_dtype
config.quantization,
config.sparsity,
high_precision_dtype=config.high_precision_dtype,
)
if aoBaseConfig is not None and config.quantization is not None:
quantize_(m_copy, aoBaseConfig)
elif config.sparsity is not None and aoBaseConfig is not None:
sparsify_(m_copy, aoBaseConfig)
else:
pass # No quantization or sparsity specified, do nothing
pass # No quantization or sparsity specified, do nothing
if config.use_torch_compile:
print("Compiling model....")
m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
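For orientation, the branch above reduces to the following pattern. This is a minimal illustrative sketch (not part of the diff), assuming a toy nn.Sequential model and an int8 weight-only recipe:

import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_
from torchao.sparsity.sparse_api import SemiSparseWeightConfig, sparsify_

# Toy stand-in for the benchmark model
model = torch.nn.Sequential(torch.nn.Linear(32, 16)).to(torch.bfloat16).eval()

# A quantization recipe string such as "int8wo" resolves to a quantization config,
# which is applied in place with quantize_
quantize_(model, Int8WeightOnlyConfig())

# A sparsity-only recipe such as "semi-sparse" resolves to a sparsity config instead,
# which would be applied with sparsify_ (typically on a CUDA device):
# sparsify_(model, SemiSparseWeightConfig())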
17 changes: 10 additions & 7 deletions benchmarks/microbenchmarks/benchmark_runner.py
@@ -30,7 +30,6 @@
generate_results_csv,
print_results,
)
from torchao import sparsity


def get_shapes_for_config(
@@ -68,16 +67,22 @@ def get_param_combinations(model_param):

return shapes, base_params


def get_quantization_sparsity_recipes(
quantization_recipes: str, sparsity_recipes: str
) -> List[Tuple[str, str]]:
"""Generate valid quantization and sparsity recipes."""

config_recipes = []
for quant_config, sparse_config in product(quantization_recipes, sparsity_recipes):
if sparse_config != "None" and quant_config != "baseline":
if sparse_config != "None":
if "semi" in sparse_config or "2:4" in sparse_config:
if "marlin" in quant_config or "int8dq" in quant_config or "float8dq" in quant_config:
if (
"marlin" in quant_config
or "int8dq" in quant_config
or "float8dq" in quant_config
or quant_config == "baseline"
):
pass
else:
continue
@@ -86,10 +91,8 @@ def get_quantization_sparsity_recipes(
else:
raise ValueError(f"Invalid sparsity recipe: {sparse_config}")
config_recipes.append((quant_config, sparse_config))
print('Generated config recipes: ', config_recipes)
return config_recipes



def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig]:
"""Load benchmark configurations from CLI arguments and YAML file."""
@@ -109,8 +112,8 @@ def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig
get_quantization_sparsity_recipes(
config.get("quantization_config_recipe_names", ["baseline"]),
config.get("sparsity_config_recipe_names", ["None"]),
),
shapes
),
shapes,
):
configs.append(
BenchmarkConfig(
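For illustration, the filtering in get_quantization_sparsity_recipes is expected to behave roughly as follows. This is a sketch based on the logic above, not part of the diff:

from benchmarks.microbenchmarks.benchmark_runner import get_quantization_sparsity_recipes

recipes = get_quantization_sparsity_recipes(
    ["baseline", "marlin", "int8wo"],
    ["None", "semi-sparse"],
)

# Every quantization recipe is paired with "None"; the "semi-sparse" recipe is kept
# only for the baseline and for quantizations with a 2:4-compatible layout
# (marlin, int8dq, float8dq), so int8wo + semi-sparse is dropped.
assert ("baseline", "None") in recipes
assert ("marlin", "semi-sparse") in recipes
assert ("int8wo", "semi-sparse") not in recipes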
7 changes: 5 additions & 2 deletions benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -3,8 +3,11 @@ benchmark_mode: "inference"
quantization_config_recipe_names:
- "baseline"
- "int4wo-32"
- "int4wo-128"
sparsity_config_recipe_name: "semi-sparse"
- "marlin"
sparsity_config_recipe_names:
- "None"
- "semi-sparse"
- "block"
output_dir: "benchmarks/microbenchmarks/results"
model_params:
- name: "small_bf16_linear"
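For illustration, a config like the one above is consumed by load_benchmark_configs, which expands the quantization and sparsity recipe lists into one BenchmarkConfig per valid combination. A hedged sketch (not part of the diff), assuming the YAML path below points at the file shown above:

import argparse
from benchmarks.microbenchmarks.benchmark_runner import load_benchmark_configs

# The path refers to the benchmark_config.yml shown above
args = argparse.Namespace(config="benchmarks/microbenchmarks/test/benchmark_config.yml")
configs = load_benchmark_configs(args)

# Each entry pairs one quantization recipe with one compatible sparsity recipe
print({(c.quantization, c.sparsity) for c in configs})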
40 changes: 40 additions & 0 deletions benchmarks/microbenchmarks/test/test_benchmark_inference.py
@@ -41,6 +41,46 @@ def test_run_inference(self):
self.assertIsInstance(result, BenchmarkResult)
self.assertTrue(hasattr(result, "model_inference_time_in_ms"))

def test_run_inference_with_sparsity(self):
"""Test running inference with sparsity configurations"""
# Test with semi-sparse config
config = BenchmarkConfig(
quantization="marlin",
sparsity="semi-sparse",
params={
"high_precision_dtype": "torch.float32",
"use_torch_compile": False,
"device": "cpu",
"model_type": "linear",
},
shape_name="custom",
shape=[16, 32, 8],
output_dir=self.temp_dir,
benchmark_mode="inference",
)
result = run(config)
self.assertIsInstance(result, BenchmarkResult)
self.assertTrue(hasattr(result, "model_inference_time_in_ms"))

# Test with block sparsity
config = BenchmarkConfig(
quantization="baseline",
sparsity="block",
params={
"high_precision_dtype": "torch.float32",
"use_torch_compile": False,
"device": "cpu",
"model_type": "linear",
},
shape_name="custom",
shape=[16, 32, 8],
output_dir=self.temp_dir,
benchmark_mode="inference",
)
result = run(config)
self.assertIsInstance(result, BenchmarkResult)
self.assertTrue(hasattr(result, "model_inference_time_in_ms"))


if __name__ == "__main__":
unittest.main()
63 changes: 63 additions & 0 deletions benchmarks/microbenchmarks/test/test_benchmark_runner.py
@@ -13,6 +13,7 @@

from benchmarks.microbenchmarks.benchmark_runner import (
get_param_combinations,
get_quantization_sparsity_recipes,
get_shapes_for_config,
load_benchmark_configs,
run_inference_benchmarks_from_config,
@@ -88,6 +89,68 @@ def test_run_inference_benchmarks_from_config(self):
results_file = Path(self.temp_dir) / "results.csv"
self.assertTrue(results_file.exists())

def test_get_quantization_sparsity_recipes(self):
"""Test generation of valid quantization and sparsity recipe combinations"""
# Test basic combinations
quant_recipes = ["baseline", "int8wo"]
sparse_recipes = ["None", "semi-sparse"]
recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
self.assertEqual(
len(recipes), 3
)  # baseline+None, int8wo+None, and baseline+semi-sparse
self.assertIn(("baseline", "None"), recipes)
self.assertIn(("int8wo", "None"), recipes)
self.assertIn(("baseline", "semi-sparse"), recipes)

# Test marlin with semi-sparse
quant_recipes = ["marlin", "baseline"]
sparse_recipes = ["None", "semi-sparse"]
recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
self.assertIn(("marlin", "semi-sparse"), recipes)
self.assertIn(("baseline", "None"), recipes)

# Test block sparsity
quant_recipes = ["baseline"]
sparse_recipes = ["None", "block"]
recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
self.assertIn(("baseline", "block"), recipes)

def test_load_benchmark_configs_with_sparsity(self):
"""Test loading benchmark configs with sparsity options"""
test_config = {
"benchmark_mode": "inference",
"quantization_config_recipe_names": ["baseline", "marlin"],
"sparsity_config_recipe_names": ["None", "semi-sparse"],
"output_dir": self.temp_dir,
"model_params": [
{
"matrix_shapes": [
{"name": "custom", "shapes": [[1024, 1024, 1024]]}
],
"high_precision_dtype": "torch.bfloat16",
"device": "cpu",
"model_type": "linear",
}
],
}

config_path = Path(self.temp_dir) / "test_sparsity_config.yml"
with open(config_path, "w") as f:
yaml.dump(test_config, f)

configs = load_benchmark_configs(argparse.Namespace(config=str(config_path)))

# Check that we get configs for baseline and marlin with appropriate sparsity
self.assertTrue(
any(c.quantization == "baseline" and c.sparsity == "None" for c in configs)
)
self.assertTrue(
any(
c.quantization == "marlin" and c.sparsity == "semi-sparse"
for c in configs
)
)


if __name__ == "__main__":
unittest.main()
43 changes: 40 additions & 3 deletions benchmarks/microbenchmarks/test/test_utils.py
@@ -13,7 +13,11 @@
from benchmarks.microbenchmarks.utils import (
BenchmarkConfig,
BenchmarkResult,
BlockSparseWeightConfig,
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Int4WeightOnlyConfig,
LNLinearSigmoid,
SemiSparseWeightConfig,
ToyLinearModel,
clean_caches,
create_model_and_input,
@@ -39,6 +43,7 @@ def setUp(self):
def test_benchmark_config(self):
config = BenchmarkConfig(
quantization="baseline",
sparsity="None",
params=self.test_params,
shape_name="custom",
shape=self.test_shape,
@@ -60,6 +65,7 @@ def test_benchmark_config(self):
def test_benchmark_result(self):
config = BenchmarkConfig(
quantization="baseline",
sparsity="None",
params=self.test_params,
shape_name="custom",
shape=self.test_shape,
@@ -82,17 +88,46 @@ def test_get_default_device(self):

def test_string_to_config(self):
# Test baseline
config = string_to_config("baseline")
config = string_to_config("baseline", "None")
self.assertIsNone(config)

# Test int8wo
config = string_to_config("int8wo")
config = string_to_config("int8wo", "None")
self.assertIsNotNone(config)

# Test invalid config
config = string_to_config("not_a_real_config")
config = string_to_config("not_a_real_config", "None")
self.assertIsNone(config)

def test_string_to_config_sparsity(self):
"""Test sparsity config generation"""
# Test semi-sparse config
config = string_to_config(None, "semi-sparse")
self.assertIsInstance(config, SemiSparseWeightConfig)

# Test block sparse config
config = string_to_config(None, "block")
self.assertIsInstance(config, BlockSparseWeightConfig)

# Test combined sparsity and quantization
config = string_to_config("marlin", "semi-sparse")
self.assertIsInstance(config, Int4WeightOnlyConfig)

# Test float8 with semi-sparse
config = string_to_config("float8dq", "semi-sparse")
self.assertIsInstance(
config, Float8DynamicActivationFloat8SemiSparseWeightConfig
)

def test_invalid_sparsity(self):
"""Test invalid sparsity config generation"""
from benchmarks.microbenchmarks.benchmark_runner import (
get_quantization_sparsity_recipes,
)

with self.assertRaises(ValueError):
get_quantization_sparsity_recipes(["baseline"], ["invalid_sparsity"])

def test_toy_linear_model(self):
model = ToyLinearModel(k=64, n=32, dtype=torch.float32)
x = torch.randn(16, 64)
@@ -139,6 +174,7 @@ def test_generate_results_csv(self):
BenchmarkResult(
BenchmarkConfig(
quantization="int8wo",
sparsity="None",
params={},
shape_name="custom",
shape=[1024, 1024, 1024],
@@ -149,6 +185,7 @@ def test_generate_results_csv(self):
BenchmarkResult(
BenchmarkConfig(
quantization="int4wo",
sparsity="None",
params={},
shape_name="custom",
shape=[1024, 1024, 1024],
17 changes: 12 additions & 5 deletions benchmarks/microbenchmarks/utils.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
import csv
import os
from pickle import NONE
from typing import Any, Dict, List, Optional

import torch
@@ -14,6 +13,7 @@

from torchao.core.config import AOBaseConfig
from torchao.quantization import (
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
FPXWeightOnlyConfig,
@@ -25,7 +25,6 @@
PerRow,
PerTensor,
UIntXWeightOnlyConfig,
Float8DynamicActivationFloat8SemiSparseWeightConfig,
)
from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig

@@ -54,7 +53,9 @@ def get_default_device(device: str = "cuda") -> str:
class BenchmarkConfig:
def __init__(
self,
quantization: Optional[str], # Quantization string format is similar to the format being used for llama/generate.py
quantization: Optional[
str
], # Quantization string format is similar to the format being used for llama/generate.py
sparsity: Optional[str], # Specify the type of sparsity to be used
params: Dict[str, Any],
shape_name: str,
@@ -147,7 +148,9 @@ def forward(self, x):
return x


def string_to_config(quantization: Optional[str], sparsity: Optional[str], **kwargs) -> AOBaseConfig:
def string_to_config(
quantization: Optional[str], sparsity: Optional[str], **kwargs
) -> AOBaseConfig:
"""Get quantization config based on quantization string.

Args:
@@ -166,6 +169,7 @@ def string_to_config(quantization: Optional[str], sparsity: Optional[str], **kwa
if quantization is None and sparsity is None:
return None
high_precision_dtype = kwargs.get("high_precision_dtype", torch.bfloat16)

if "int4wo" in quantization and not HAS_TRITON:
print("Warning: Triton not available, falling back to baseline")
return None
@@ -177,6 +181,8 @@ def string_to_config(quantization: Optional[str], sparsity: Optional[str], **kwa
return Int8WeightOnlyConfig()
if "int8dq" in quantization:
if sparsity is not None and ("semi" in sparsity or "2:4" in sparsity):
from torchao.dtypes import SemiSparseLayout

return Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout())
elif "int8dq_prefill_wo_decode" in quantization:
return Int8DynamicActivationInt8WeightConfig(weight_only_decode=True)
@@ -215,6 +221,7 @@ def string_to_config(quantization: Optional[str], sparsity: Optional[str], **kwa
)
elif sparsity is not None and ("semi" in sparsity or "2:4" in sparsity):
from torchao.dtypes import MarlinSparseLayout

return Int4WeightOnlyConfig(layout=MarlinSparseLayout())
if "fp6" in quantization:
return FPXWeightOnlyConfig(3, 2)
@@ -265,7 +272,7 @@ def string_to_config(quantization: Optional[str], sparsity: Optional[str], **kwa
return Float8WeightOnlyConfig()
elif "float8dq" in quantization:
if sparsity and "semi" in sparsity:
return Float8DynamicActivationFloat8SemiSparseWeightConfig()
return Float8DynamicActivationFloat8SemiSparseWeightConfig()
granularity = str(quantization.split("-")[-1])
if granularity == "tensor":
granularity = PerTensor()
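For illustration, string_to_config can be exercised on its own to see which AOBaseConfig a (quantization, sparsity) pair resolves to. A minimal sketch mirroring the tests above (not part of the diff):

from benchmarks.microbenchmarks.utils import string_to_config
from torchao.quantization import Int4WeightOnlyConfig
from torchao.sparsity.sparse_api import SemiSparseWeightConfig

# A sparsity-only recipe resolves to a sparsity config
assert isinstance(string_to_config(None, "semi-sparse"), SemiSparseWeightConfig)

# "marlin" with 2:4 sparsity resolves to int4 weight-only quantization on a Marlin sparse layout
assert isinstance(string_to_config("marlin", "semi-sparse"), Int4WeightOnlyConfig)

# Baseline with no sparsity resolves to no config at all
assert string_to_config("baseline", "None") is None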