From d9fdd6919d972767e84cf43f5106c681383dc00d Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 25 Mar 2025 15:28:09 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 .../microbenchmarks/benchmark_inference.py    | 23 +++--
 .../microbenchmarks/benchmark_runner.py       | 84 +++++++++++++------
 .../microbenchmarks/results/results.csv       | 14 ----
 .../microbenchmarks/test/benchmark_config.yml |  4 +-
 .../test/test_benchmark_inference.py          | 24 +++++-
 .../test/test_benchmark_runner.py             | 65 +++++++++++---
 benchmarks/microbenchmarks/test/test_utils.py | 17 ++++
 benchmarks/microbenchmarks/utils.py           | 12 ++-
 8 files changed, 178 insertions(+), 65 deletions(-)
 delete mode 100644 benchmarks/microbenchmarks/results/results.csv

diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py
index ea91deca12..d84e21f382 100644
--- a/benchmarks/microbenchmarks/benchmark_inference.py
+++ b/benchmarks/microbenchmarks/benchmark_inference.py
@@ -50,13 +50,26 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
         config.sparsity,
         high_precision_dtype=config.high_precision_dtype,
     )
-    if config.sparsity != "None" and config.quantization == "baseline":
-        print(f"Sparsifying model for sparsity: {config.sparsity}")
-        sparsify_(m_copy, aoBaseConfig)
-    elif config.sparsity == "None" and config.quantization == "baseline":
+
+    # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
+    is_cuda = config.device == "cuda" and torch.cuda.is_available()
+
+    if config.sparsity is not None and (
+        config.quantization is None or "baseline" in config.quantization
+    ):
+        if is_cuda:
+            print(f"Applying {config.sparsity} sparsity to model")
+            sparsify_(m_copy, aoBaseConfig)
+        else:
+            print(
+                f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
+            )
+    elif config.sparsity is None and (
+        config.quantization is None or "baseline" in config.quantization
+    ):
         pass  # No quantization or sparsity specified, do nothing
     else:
-        print(f"Quantizing model with quantization: {config.quantization}, sparsity: {config.sparsity}")
+        print("Quantizing model....")
         quantize_(m_copy, aoBaseConfig)
 
     if config.use_torch_compile:
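Reviewer note: the new branching in run() above reduces to a small decision table. The following is a minimal standalone sketch of that dispatch, where the dispatch helper and its string results are illustrative stand-ins for the real sparsify_/quantize_ calls, not part of the patch:

from typing import Optional


def dispatch(quantization: Optional[str], sparsity: Optional[str], device: str, cuda_available: bool) -> str:
    """Mirror the sparsify/quantize branching in benchmark_inference.run()."""
    is_cuda = device == "cuda" and cuda_available
    baseline = quantization is None or "baseline" in quantization
    if sparsity is not None and baseline:
        # Sparsity-only configs need CUDA; otherwise they are skipped with a warning.
        return "sparsify" if is_cuda else "skip (sparsity requires CUDA)"
    elif sparsity is None and baseline:
        return "no-op"  # nothing requested
    else:
        return "quantize"  # any non-baseline quantization (with or without sparsity)


assert dispatch(None, "semi-sparse", "cuda", True) == "sparsify"
assert dispatch(None, "semi-sparse", "cpu", False) == "skip (sparsity requires CUDA)"
assert dispatch("int8wo", None, "cpu", False) == "quantize"
assert dispatch(None, None, "cuda", True) == "no-op"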
diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py
index 7ba6d6b6dc..85a8ef2f53 100644
--- a/benchmarks/microbenchmarks/benchmark_runner.py
+++ b/benchmarks/microbenchmarks/benchmark_runner.py
@@ -21,7 +21,7 @@
 import argparse
 from itertools import product
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import yaml
 
@@ -69,28 +69,57 @@ def get_param_combinations(model_param):
 
 
 def get_quantization_sparsity_recipes(
-    quantization_recipes: str, sparsity_recipes: str
-) -> List[Tuple[str, str]]:
-    """Generate valid quantization and sparsity recipes."""
-
-    config_recipes = []
-    for quant_config, sparse_config in product(quantization_recipes, sparsity_recipes):
-        if sparse_config != "None":
-            if "semi" in sparse_config or "2:4" in sparse_config:
+    quantization_recipes: List[str], sparsity_recipes: List[str]
+) -> Set[Tuple[str, Optional[str]]]:
+    """Generate valid quantization and sparsity recipes.
+
+    Args:
+        quantization_recipes: List of quantization recipes
+        sparsity_recipes: List of sparsity recipes
+
+    Returns:
+        Set of tuples containing (quantization_recipe, sparsity_recipe)
+        For block sparsity, quantization is always "baseline"
+        All quantization techniques are also run without sparsity
+    """
+    config_recipes = set()
+
+    # Handle edge cases
+    if sparsity_recipes is None and quantization_recipes is None:
+        return {("baseline", None)}
+    if sparsity_recipes is None:
+        return {(quant, None) for quant in quantization_recipes}
+    if quantization_recipes is None:
+        return {("baseline", sparse) for sparse in sparsity_recipes}
+
+    # Always include baseline without sparsity
+    config_recipes.add(("baseline", None))
+
+    # Add all quantization techniques without sparsity
+    for quant_config in quantization_recipes:
+        config_recipes.add((quant_config, None))
+
+    # Process combinations of quantization and sparsity
+    for sparse_config in sparsity_recipes:
+        if sparse_config is None:
+            # Skip None sparsity as we've already added all quantization techniques without sparsity
+            continue
+        elif "block" in sparse_config:
+            # For block sparsity, only pair with baseline quantization
+            config_recipes.add(("baseline", sparse_config))
+        elif "semi" in sparse_config or "2:4" in sparse_config:
+            # For semi-sparse, only pair with compatible quantization methods
+            for quant_config in quantization_recipes:
                 if (
                     "marlin" in quant_config
                     or "int8dq" in quant_config
                     or "float8dq" in quant_config
                     or quant_config == "baseline"
                 ):
-                    pass
-                else:
-                    continue
-            elif sparse_config == "block":
-                config_recipes.append(("baseline", sparse_config))
-            else:
-                raise ValueError(f"Invalid sparsity recipe: {sparse_config}")
-        config_recipes.append((quant_config, sparse_config))
+                    config_recipes.add((quant_config, sparse_config))
+        else:
+            raise ValueError(f"Invalid sparsity recipe: {sparse_config}")
+
     return config_recipes
 
 
@@ -104,15 +133,16 @@ def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig
 
     # Create all possible combinations
     configs = []
+    quantization_sparsity_recipes = get_quantization_sparsity_recipes(
+        config.get("quantization_config_recipe_names", None),
+        config.get("sparsity_config_recipe_names", None),
+    )
     for model_param in config["model_params"]:
         shapes, params = get_param_combinations(model_param)
 
         # Create configs for all combinations
         for (quant_config, sparse_config), (shape_name, shape) in product(
-            get_quantization_sparsity_recipes(
-                config.get("quantization_config_recipe_names", ["baseline"]),
-                config.get("sparsity_config_recipe_names", ["None"]),
-            ),
+            quantization_sparsity_recipes,
             shapes,
         ):
             configs.append(
@@ -126,7 +156,6 @@ def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig
                     benchmark_mode=benchmark_mode,
                 )
             )
-
     return configs
 
 
@@ -135,14 +164,17 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None
     from benchmarks.microbenchmarks.benchmark_inference import run as run_inference
 
     results = []
-    print("Benchmarking Inference ......")
+    print("----------------- RUNNING BENCHMARKS FOR INFERENCE -----------------------")
    for config in configs:
+        print("----------------------------------------")
         try:
-            print(f"Running: {config.name}")
+            print(
+                f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}"
+            )
             result = run_inference(config)  # Pass the config object directly
             results.append(result)
-        except Exception as e:
-            print(f"Error running benchmark {config.name}: {e}")
+        except Exception:
+            print(f"Error running benchmark {config.name}")
             continue
 
     # Add results to csv
diff --git a/benchmarks/microbenchmarks/results/results.csv b/benchmarks/microbenchmarks/results/results.csv
deleted file mode 100644
index d1e2fe188e..0000000000
--- a/benchmarks/microbenchmarks/results/results.csv
+++ /dev/null
@@ -1,14 +0,0 @@
-name,quantization,m,k,n,high_precision_dtype,use_torch_compile,torch_compile_mode,device,model_type,output_dir,model_inference_time_in_ms
-small_bf16_linear,baseline,1024,1024,1024,torch.bfloat16,True,max-autotune,cuda,linear,benchmarks/microbenchmarks/results,47.870889538899064
-small_bf16_linear,int4wo-32,1024,1024,1024,torch.bfloat16,True,max-autotune,cuda,linear,benchmarks/microbenchmarks/results,92.79820020310581
-small_bf16_linear,int4wo-32,1024,1024,1024,torch.bfloat16,True,max-autotune,cuda,linear,benchmarks/microbenchmarks/results,92.79379970394075
-small_bf16_linear,marlin,1024,1024,1024,torch.bfloat16,True,max-autotune,cuda,linear,benchmarks/microbenchmarks/results,2827.8696595225483
-large_bf16_ln_linear,baseline,2048,4096,1024,torch.bfloat16,True,max-autotune,cuda,ln_linear_sigmoid,benchmarks/microbenchmarks/results,69.42827021703124
-large_bf16_ln_linear,baseline,4096,4096,1024,torch.bfloat16,True,max-autotune,cuda,ln_linear_sigmoid,benchmarks/microbenchmarks/results,110.21242011338472
-large_bf16_ln_linear,int4wo-32,2048,4096,1024,torch.bfloat16,True,max-autotune,cuda,ln_linear_sigmoid,benchmarks/microbenchmarks/results,640.171259874478
-large_bf16_ln_linear,int4wo-32,4096,4096,1024,torch.bfloat16,True,max-autotune,cuda,ln_linear_sigmoid,benchmarks/microbenchmarks/results,1230.195889947936
-large_bf16_ln_linear,int4wo-32,2048,4096,1024,torch.bfloat16,True,max-autotune,cuda,ln_linear_sigmoid,benchmarks/microbenchmarks/results,638.9500107616186
-large_bf16_ln_linear,int4wo-32,4096,4096,1024,torch.bfloat16,True,max-autotune,cuda,ln_linear_sigmoid,benchmarks/microbenchmarks/results,1224.7836799360812
-large_bf16_ln_linear,marlin,2048,4096,1024,torch.bfloat16,True,max-autotune,cuda,ln_linear_sigmoid,benchmarks/microbenchmarks/results,10065.855470020324
-large_bf16_ln_linear,marlin,4096,4096,1024,torch.bfloat16,True,max-autotune,cuda,ln_linear_sigmoid,benchmarks/microbenchmarks/results,11008.49203998223
-cpu_fp32_linear,baseline,4096,4096,1024,torch.float32,False,,cpu,linear,benchmarks/microbenchmarks/results,369783.37747976184
diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml
index 074dfeabf4..17cd666bfa 100644
--- a/benchmarks/microbenchmarks/test/benchmark_config.yml
+++ b/benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -1,11 +1,11 @@
 # Sample configuration for inference benchmarks
 benchmark_mode: "inference"
 quantization_config_recipe_names:
-  - "baseline"
+  # - "baseline" Will always run a baseline instance
   - "int4wo-32"
   - "marlin"
 sparsity_config_recipe_names:
-  - "None"
+  # - "none" Will always run an instance without sparsity
   - "semi-sparse"
   - "block"
 output_dir: "benchmarks/microbenchmarks/results"
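For reference, given the recipe lists left active in this sample config, the new get_quantization_sparsity_recipes expands to the set below. This is an illustrative check only; the import assumes this repository is on the Python path:

from benchmarks.microbenchmarks.benchmark_runner import get_quantization_sparsity_recipes

recipes = get_quantization_sparsity_recipes(
    quantization_recipes=["int4wo-32", "marlin"],
    sparsity_recipes=["semi-sparse", "block"],
)
assert recipes == {
    ("baseline", None),  # baseline is always included
    ("int4wo-32", None),  # every quantization recipe also runs without sparsity
    ("marlin", None),
    ("marlin", "semi-sparse"),  # semi-sparse pairs only with marlin/int8dq/float8dq/baseline
    ("baseline", "block"),  # block sparsity pairs only with baseline
}

Note that ("baseline", "semi-sparse") is not generated here because "baseline" is commented out of the quantization list; the function only pairs semi-sparse with recipes actually present in quantization_recipes.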
diff --git a/benchmarks/microbenchmarks/test/test_benchmark_inference.py b/benchmarks/microbenchmarks/test/test_benchmark_inference.py
index 55052acdaa..0216297bc7 100644
--- a/benchmarks/microbenchmarks/test/test_benchmark_inference.py
+++ b/benchmarks/microbenchmarks/test/test_benchmark_inference.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import tempfile
 import unittest
+from unittest.mock import patch
 
 from benchmarks.microbenchmarks.benchmark_inference import run
 from benchmarks.microbenchmarks.utils import BenchmarkConfig, BenchmarkResult
@@ -36,14 +37,28 @@ def tearDown(self):
         shutil.rmtree(self.temp_dir)
 
-    def test_run_inference(self):
+    @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
+    def test_run_inference(self, mock_string_to_config):
+        # Mock string_to_config to return a valid config
+        from torchao.sparsity.sparse_api import SemiSparseWeightConfig
+
+        mock_string_to_config.return_value = SemiSparseWeightConfig()
+
         result = run(self.config)
         self.assertIsInstance(result, BenchmarkResult)
         self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
 
-    def test_run_inference_with_sparsity(self):
+    @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
+    def test_run_inference_with_sparsity(self, mock_string_to_config):
         """Test running inference with sparsity configurations"""
+        # Mock string_to_config to return valid configs
+        from torchao.quantization import Int4WeightOnlyConfig
+        from torchao.sparsity.sparse_api import (
+            BlockSparseWeightConfig,
+        )
 
+        # Test with semi-sparse config
+        mock_string_to_config.return_value = Int4WeightOnlyConfig()
         config = BenchmarkConfig(
             quantization="marlin",
             sparsity="semi-sparse",
@@ -54,7 +69,7 @@ def test_run_inference_with_sparsity(self):
                 "model_type": "linear",
             },
             shape_name="custom",
-            shape=[16, 32, 8],
+            shape=[64, 64, 64],  # Use dimensions divisible by 64
             output_dir=self.temp_dir,
             benchmark_mode="inference",
         )
@@ -63,6 +78,7 @@ def test_run_inference_with_sparsity(self):
         self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
 
         # Test with block sparsity
+        mock_string_to_config.return_value = BlockSparseWeightConfig()
         config = BenchmarkConfig(
             quantization="baseline",
             sparsity="block",
@@ -73,7 +89,7 @@ def test_run_inference_with_sparsity(self):
                 "model_type": "linear",
             },
             shape_name="custom",
-            shape=[16, 32, 8],
+            shape=[64, 64, 64],  # Use dimensions divisible by 64
             output_dir=self.temp_dir,
             benchmark_mode="inference",
         )
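A note on the mocking pattern used in these tests: unittest.mock.patch must target the name where it is looked up, here benchmarks.microbenchmarks.benchmark_inference.string_to_config, not where it is defined (utils.py), because benchmark_inference appears to import the function into its own namespace at import time. A self-contained sketch of the same idea, using a stand-in namespace object rather than the real modules:

import types
from unittest.mock import patch

# Stand-in for the consumer module that imported string_to_config.
consumer = types.SimpleNamespace(string_to_config=lambda quant, sparse: "real")

def run_benchmark():
    # Looks the name up through `consumer`, mirroring how benchmark_inference
    # calls string_to_config after importing it.
    return consumer.string_to_config("int4wo-32", None)

with patch.object(consumer, "string_to_config", return_value="mocked"):
    assert run_benchmark() == "mocked"  # the lookup site sees the mock
assert run_benchmark() == "real"  # original restored on exit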
["None", "block"] + sparse_recipes = [None, "block"] recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes) self.assertIn(("baseline", "block"), recipes) - def test_load_benchmark_configs_with_sparsity(self): + def test_none_string_raises_error(self): + """Test that passing 'None' as a string raises an error""" + quant_recipes = ["baseline"] + sparse_recipes = ["None"] # "None" as a string should raise an error + with self.assertRaises(ValueError): + get_quantization_sparsity_recipes(quant_recipes, sparse_recipes) + + def test_block_sparsity_with_quantization(self): + """Test that block sparsity is only paired with baseline quantization""" + quant_recipes = ["baseline", "int8wo", "int4wo", "marlin"] + sparse_recipes = ["block"] + recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes) + + # Block sparsity should only be paired with baseline + self.assertIn(("baseline", "block"), recipes) + self.assertNotIn(("int8wo", "block"), recipes) + self.assertNotIn(("int4wo", "block"), recipes) + self.assertNotIn(("marlin", "block"), recipes) + + # All quantization techniques should be run without sparsity + self.assertIn(("baseline", None), recipes) + self.assertIn(("int8wo", None), recipes) + self.assertIn(("int4wo", None), recipes) + self.assertIn(("marlin", None), recipes) + + def test_all_quantization_without_sparsity(self): + """Test that all quantization techniques are run without sparsity""" + quant_recipes = ["baseline", "int8wo", "int4wo", "marlin"] + sparse_recipes = [None, "semi-sparse", "block"] + recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes) + + # All quantization techniques should be run without sparsity + for quant in quant_recipes: + self.assertIn((quant, None), recipes) + + @patch( + "benchmarks.microbenchmarks.benchmark_runner.get_quantization_sparsity_recipes" + ) + def test_load_benchmark_configs_with_sparsity(self, mock_get_recipes): """Test loading benchmark configs with sparsity options""" + # Mock get_quantization_sparsity_recipes to return a valid set of recipes + mock_get_recipes.return_value = {("baseline", None), ("marlin", "semi-sparse")} + test_config = { "benchmark_mode": "inference", "quantization_config_recipe_names": ["baseline", "marlin"], - "sparsity_config_recipe_names": ["None", "semi-sparse"], + "sparsity_config_recipe_names": [ + None, + "semi-sparse", + ], # Use None instead of "None" "output_dir": self.temp_dir, "model_params": [ { @@ -142,7 +183,7 @@ def test_load_benchmark_configs_with_sparsity(self): # Check that we get configs for baseline and marlin with appropriate sparsity self.assertTrue( - any(c.quantization == "baseline" and c.sparsity == "None" for c in configs) + any(c.quantization == "baseline" and c.sparsity is None for c in configs) ) self.assertTrue( any( diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index ae7e88c9ae..83e88c5b11 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -119,6 +119,23 @@ def test_string_to_config_sparsity(self): config, Float8DynamicActivationFloat8SemiSparseWeightConfig ) + def test_block_sparsity_with_baseline_quantization(self): + """Test that block sparsity with baseline quantization returns BlockSparseWeightConfig""" + config = string_to_config("baseline", "block") + self.assertIsInstance(config, BlockSparseWeightConfig) + + def test_block_sparsity_with_non_baseline_quantization(self): + """Test that block sparsity with 
diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py
index ae7e88c9ae..83e88c5b11 100644
--- a/benchmarks/microbenchmarks/test/test_utils.py
+++ b/benchmarks/microbenchmarks/test/test_utils.py
@@ -119,6 +119,23 @@ def test_string_to_config_sparsity(self):
             config, Float8DynamicActivationFloat8SemiSparseWeightConfig
         )
 
+    def test_block_sparsity_with_baseline_quantization(self):
+        """Test that block sparsity with baseline quantization returns BlockSparseWeightConfig"""
+        config = string_to_config("baseline", "block")
+        self.assertIsInstance(config, BlockSparseWeightConfig)
+
+    def test_block_sparsity_with_non_baseline_quantization(self):
+        """Test that block sparsity with non-baseline quantization still returns BlockSparseWeightConfig"""
+        # Block sparsity should take precedence over any quantization method
+        config = string_to_config("int8wo", "block")
+        self.assertIsInstance(config, BlockSparseWeightConfig)
+
+        config = string_to_config("int4wo", "block")
+        self.assertIsInstance(config, BlockSparseWeightConfig)
+
+        config = string_to_config("marlin", "block")
+        self.assertIsInstance(config, BlockSparseWeightConfig)
+
     def test_invalid_sparsity(self):
         """Test invalid sparsity config generation"""
         from benchmarks.microbenchmarks.benchmark_runner import (
diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py
index 076d9aa180..fd3db11591 100644
--- a/benchmarks/microbenchmarks/utils.py
+++ b/benchmarks/microbenchmarks/utils.py
@@ -95,6 +95,7 @@ def to_dict(self) -> Dict[str, Any]:
         return {
             "name": self.name,
             "quantization": self.quantization,
+            "sparsity": self.sparsity,
             "m": self.m,
             "k": self.k,
             "n": self.n,
@@ -161,11 +162,16 @@ def string_to_config(
     Returns:
         AOBaseConfig: Quantization configuration object
     """
+    # Handle block sparsity first: it takes precedence, so any quantization setting is ignored
+    if sparsity == "block":
+        return BlockSparseWeightConfig()
+
+    # Handle other sparsity cases
     if quantization is None and sparsity is not None:
         if "semi" in sparsity or "2:4" in sparsity:
             return SemiSparseWeightConfig()
-        if sparsity == "block":
-            return BlockSparseWeightConfig()
+        else:
+            raise ValueError(f"Unknown sparsity type: {sparsity}")
     if quantization is None and sparsity is None:
         return None
     high_precision_dtype = kwargs.get("high_precision_dtype", torch.bfloat16)
@@ -424,6 +430,8 @@ def print_results(results: List[BenchmarkResult]):
         row = []
         for col in display_columns:
             value = result_dict.get(col, "N/A")
+            if value is None:
+                value = "N/A"
             if col == "model_inference_time_in_ms":
                 value = f"{value:.2f}" if isinstance(value, (int, float)) else value
             elif col == "use_torch_compile":
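Finally, the print_results change guards against None values now that sparsity is part of each result row and may legitimately be absent. A small self-contained sketch of the normalization, mirroring the loop above with illustrative data:

# Illustrative data; the real result_dict comes from BenchmarkResult.to_dict().
result_dict = {"quantization": "marlin", "sparsity": None, "model_inference_time_in_ms": 12.3456}
display_columns = ["quantization", "sparsity", "model_inference_time_in_ms"]

row = []
for col in display_columns:
    value = result_dict.get(col, "N/A")
    if value is None:
        value = "N/A"  # render missing sparsity (or any None field) as "N/A"
    if col == "model_inference_time_in_ms":
        value = f"{value:.2f}" if isinstance(value, (int, float)) else value
    row.append(value)

assert row == ["marlin", "N/A", "12.35"]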