Empty file added benchmarks/__init__.py
84 changes: 84 additions & 0 deletions benchmarks/microbenchmarks/README.md
@@ -0,0 +1,84 @@
# Microbenchmarks

This directory contains microbenchmarking tools for measuring inference performance across different quantization methods and model architectures.

## Overview

The microbenchmarking system works as follows:

![Microbenchmarks Process Flow](../../docs/static/microbenchmarking_process_diagram.png)

## Components

![Microbenchmarks Flow](../../docs/static/microbenchmarks_code_flow_diagram.png)

- **benchmark_runner.py**: Main entry point that orchestrates the benchmarking process
- **benchmark_inference.py**: Handles model creation and inference benchmarking
- **utils.py**: Contains utility functions and configuration classes
- **test/**: Test files and sample configurations

## Usage

1. Create a configuration YAML file (see example below)
2. Run the benchmark using:

```bash
python -m benchmarks.microbenchmarks.benchmark_runner --config path/to/config.yml
```
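
For quick experiments, the same flow can also be driven from Python; a minimal sketch, assuming the repository root is on `PYTHONPATH` and reusing the sample config from the test directory (the CLI above remains the supported entry point):

```python
import argparse

from benchmarks.microbenchmarks.benchmark_runner import (
    load_benchmark_configs,
    run_inference_benchmarks_from_config,
)

# load_benchmark_configs() only reads the `config` attribute of the namespace
args = argparse.Namespace(config="benchmarks/microbenchmarks/test/benchmark_config.yml")
configs = load_benchmark_configs(cli_args=args)
run_inference_benchmarks_from_config(configs)
```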

### Example Configuration

```yaml
# Sample configuration for inference benchmarks
quantization_config_recipe_names:
  - "baseline"
  - "int8wo"
  - "int4wo-128"
  - "int4wo-128-hqq"

output_dir: "benchmarks/microbenchmarks/results"

model_params:
  - name: "custom_linear"
    matrix_shapes:
      - name: "custom"
        shapes: [
          [1024, 1024, 1024], # [m, k, n]
          [2048, 4096, 1024],
          [4096, 4096, 1024]
        ]
    high_precision_dtype: "torch.bfloat16"
    use_torch_compile: true
    torch_compile_mode: "max-autotune" # Options: "default", "max-autotune"
    device: "cuda" # Options: "cuda", "mps", "xpu", "cpu"
    model_type: "linear" # Options: "linear", "ln_linear_sigmoid"
```

## Configuration Options

### Quantization Methods
Currently, the quantization recipe strings use the same format as the ones passed to llama/generate.py; the sketch after this list shows how a recipe name is resolved and applied.
- `baseline`: No quantization
- `int8wo`: 8-bit weight-only quantization
- `int4wo-{group_size}`: 4-bit weight-only quantization with specified group size
- `int4wo-{group_size}-hqq`: 4-bit weight-only quantization with HQQ
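
Under the hood, each recipe name is converted to a quantization config and applied with `quantize_`; a minimal sketch of that flow, mirroring `benchmark_inference.py` (and assuming `high_precision_dtype` is passed as a `torch.dtype` and a CUDA device is available):

```python
import torch

from benchmarks.microbenchmarks.utils import string_to_config
from torchao.quantization import quantize_

# Toy model standing in for the benchmarked module; assumes a CUDA device
model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")

# "baseline" resolves to None, i.e. the model is left unquantized
quant_config = string_to_config("int4wo-128", high_precision_dtype=torch.bfloat16)
if quant_config is not None:
    quantize_(model, quant_config)
```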

### Model Types
- `linear`: Simple linear layer
- `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid
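
As a rough illustration only (the actual modules are constructed by `create_model_and_input` in `utils.py`, so treat this as an approximation rather than the exact implementation):

```python
import torch.nn as nn

k, n = 1024, 1024  # hypothetical in/out features taken from an [m, k, n] shape

# "linear": a single linear layer
linear = nn.Linear(k, n)

# "ln_linear_sigmoid": LayerNorm -> Linear -> Sigmoid
ln_linear_sigmoid = nn.Sequential(
    nn.LayerNorm(k),
    nn.Linear(k, n),
    nn.Sigmoid(),
)
```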

### Device Options
- `cuda`: NVIDIA GPU
- `xpu`: Intel GPU
- `mps`: Apple Silicon GPU
- `cpu`: CPU fallback
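
The device string is handed to PyTorch as-is, so the chosen backend must be available in the local build. A small hedged helper (not part of the benchmark code) for falling back to CPU when it is not:

```python
import torch

def pick_device(preferred: str = "cuda") -> str:
    """Return `preferred` if that backend is usable in this build, else "cpu"."""
    if preferred == "cuda" and torch.cuda.is_available():
        return "cuda"
    if preferred == "mps" and torch.backends.mps.is_available():
        return "mps"
    if preferred == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"
```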

## Output

Results are saved to a CSV file in the specified output directory.
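
The CSV can be inspected with any spreadsheet tool or with pandas; the filename below is a placeholder, since the actual name is chosen by `generate_results_csv()` in `utils.py`:

```python
import pandas as pd

# Placeholder path: check the configured output_dir for the file actually written
df = pd.read_csv("benchmarks/microbenchmarks/results/results.csv")
print(df.head())
```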

## Running Tests

To run the test suite:

```bash
python -m unittest discover benchmarks/microbenchmarks/test
```
75 changes: 75 additions & 0 deletions benchmarks/microbenchmarks/benchmark_inference.py
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Inference benchmark runner

This script runs inference benchmarks and generates a micro-benchmarking report.
- run() is the main entry point for running inference benchmarks.
"""

from copy import deepcopy
from pathlib import Path

import torch

from benchmarks.microbenchmarks.utils import (
BenchmarkConfig,
BenchmarkResult,
clean_caches,
create_model_and_input,
model_inference_time_in_ms,
string_to_config,
)
from torchao.quantization import quantize_


def run(config: BenchmarkConfig) -> BenchmarkResult:
"""Run inference benchmarks"""
clean_caches() # Clean caches

# Create output directory if it doesn't exist
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

base_model, input_data = create_model_and_input(
config.model_type,
config.m,
config.k,
config.n,
high_precision_dtype=config.high_precision_dtype,
device=config.device,
)

# Use quantize_ to apply each quantization function to the model
m_copy = deepcopy(base_model).eval().to(config.device)
quantization_config = string_to_config(
config.quantization, high_precision_dtype=config.high_precision_dtype
)
if quantization_config is not None:
quantize_(m_copy, quantization_config)
if config.use_torch_compile:
print("Compiling model....")
m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)

# Run benchmarks
result = BenchmarkResult(config=config)

# Benchmark time to run an inference call for quantized model
result.model_inference_time_in_ms = model_inference_time_in_ms(
model=m_copy, input_data=input_data
)

# TODO: Benchmark time using profiler
# Profile dtype model evaluation
# prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype)
# prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json") # Save profiling details

# TODO: Benchmark gemm time using cuda graph
# gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs)

# TODO: Benchmark op with cuda graph
# time = benchmark_op_with_cuda_graph(op, args)

return result
153 changes: 153 additions & 0 deletions benchmarks/microbenchmarks/benchmark_runner.py
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Benchmark Runner

This is the main entry point for the benchmarking application. It reads the YAML configuration
file and orchestrates the entire benchmarking process by:
- Loading and validating benchmark configurations
- Executing benchmark scenarios
- Collecting and processing results
- Generating reports

Usage:
python benchmark_runner.py [config.yaml]

The YAML file should contain all necessary configuration parameters for the benchmarks.
"""

import argparse
from itertools import product
from typing import Any, Dict, List, Tuple

import yaml

from benchmarks.microbenchmarks.utils import (
BenchmarkConfig,
generate_results_csv,
print_results,
)


def get_shapes_for_config(
shape_configs: List[Dict[str, Any]],
) -> List[Tuple[str, List[int]]]:
"""Get shapes for a given configuration.

Args:
shape_configs: List of shape configurations from YAML

Returns:
List of tuples containing (shape_name, shape)
"""
shapes = []
for shape_config in shape_configs:
name = shape_config["name"]
if name == "custom":
shapes.extend([(name, shape) for shape in shape_config["shapes"]])
else:
raise NotImplementedError(
f"Shape config {name} not supported. Currently only supports custom shapes."
)
return shapes


def get_param_combinations(model_param):
"""Extract all parameter combinations from a model config"""
# Get all shapes
shapes = get_shapes_for_config(model_param["matrix_shapes"])

# Extract all other parameters (excluding matrix_shapes)
base_params = {
key: value for key, value in model_param.items() if key not in ["matrix_shapes"]
}

return shapes, base_params


def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig]:
"""Load benchmark configurations from CLI arguments and YAML file."""
with open(cli_args.config, "r") as f:
config = yaml.safe_load(f)

output_dir = config.get("output_dir", "benchmarks/microbenchmarks/results")
benchmark_mode = config.get("benchmark_mode", "inference")

# Create all possible combinations
configs = []
for model_param in config["model_params"]:
shapes, params = get_param_combinations(model_param)

# Create configs for all combinations
for quant_config, (shape_name, shape) in product(
config.get("quantization_config_recipe_names", ["baseline"]), shapes
):
configs.append(
BenchmarkConfig(
quantization=quant_config,
params=params,
shape_name=shape_name,
shape=shape,
output_dir=output_dir,
benchmark_mode=benchmark_mode,
)
)

return configs


def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None:
"""Run benchmarks using configurations from YAML file"""
from benchmarks.microbenchmarks.benchmark_inference import run as run_inference

results = []
print("Benchmarking Inference ......")
for config in configs:
try:
print(f"Running: {config.name}")
result = run_inference(config) # Pass the config object directly
results.append(result)
except Exception as e:
print(f"Error running benchmark {config.name}: {e}")
continue

# Add results to csv
generate_results_csv(results, configs[0].output_dir)

# Print results
print_results(results)

# TODO: Process results: Speedups:
# 1. For different shapes for same model and quantization
# 2. For different quantizations for same model and shape
# 3. For different models for same quantization


if __name__ == "__main__":

parser = argparse.ArgumentParser(description="Run benchmarks from config file")
parser.add_argument(
"--config",
type=str,
required=True,
help="Path to benchmark configuration file",
)
# TODO: Add support for args to override config values and run smaller benchmarks
args = parser.parse_args()

configs = load_benchmark_configs(cli_args=args)
# Run benchmarks
if configs[0].benchmark_mode == "inference":
run_inference_benchmarks_from_config(configs)
elif configs[0].benchmark_mode == "training":
print("Training mode not implemented yet")
else:
raise ValueError(
f"Invalid benchmark mode: {configs[0].benchmark_mode}, choose from inference or training"
)

# TODO: Add support for args to override config values and run smaller benchmarks
43 changes: 43 additions & 0 deletions benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -0,0 +1,43 @@
# Sample configuration for inference benchmarks
benchmark_mode: "inference"
quantization_config_recipe_names:
- "baseline"
- "int4wo-32"
- "int4wo-128"
output_dir: "benchmarks/microbenchmarks/results"
model_params:
- name: "small_bf16_linear"
matrix_shapes:
- name: "custom"
shapes: [
[1024, 1024, 1024], # [m, k, n]
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "linear"

- name: "large_bf16_ln_linear"
matrix_shapes:
- name: "custom"
shapes: [
[2048, 4096, 1024],
[4096, 4096, 1024]
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "ln_linear_sigmoid"

- name: "cpu_fp32_linear"
matrix_shapes:
- name: "custom"
shapes: [
[4096, 4096, 1024]
]
high_precision_dtype: "torch.float32"
use_torch_compile: false
device: "cpu"
model_type: "linear"