Commit 2193189

Updates
1 parent f400fef

File tree

2 files changed: +4 -6 lines

benchmarks/microbenchmarks/benchmark_inference.py
benchmarks/microbenchmarks/utils.py

benchmarks/microbenchmarks/benchmark_inference.py

Lines changed: 3 additions & 3 deletions

@@ -16,7 +16,7 @@
     clean_caches,
     create_model_and_input,
     model_inference_time_in_ms,
-    quantization_string_to_quantization_config,
+    string_to_config,
 )
 from torchao.quantization import quantize_

@@ -39,10 +39,10 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:

     # Use quantize_ to apply each quantization function to the model
     m_copy = deepcopy(base_model).eval().to(config.device)
-    quantization_config = quantization_string_to_quantization_config(
+    quantization_config = string_to_config(
         config.quantization, high_precision_dtype=config.high_precision_dtype
     )
-    if quantization_config:
+    if quantization_config is not None:
         quantize_(m_copy, quantization_config)
     if config.use_torch_compile:
         print("Compiling model....")

benchmarks/microbenchmarks/utils.py

Lines changed: 1 addition & 3 deletions

@@ -125,9 +125,7 @@ def get_default_device() -> str:
     return "cpu"


-def quantization_string_to_quantization_config(
-    quantization: str, **kwargs
-) -> AOBaseConfig:
+def string_to_config(quantization: str, **kwargs) -> AOBaseConfig:
     """Get quantization config based on quantization string.

     Args:
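The diff only shows the new signature and the start of the docstring. As a rough sketch of how such a helper typically dispatches, assuming torchao's Int8WeightOnlyConfig/Int4WeightOnlyConfig config classes and AOBaseConfig from torchao.core.config (the actual mapping inside utils.py is not shown in this commit and will differ):

from torchao.core.config import AOBaseConfig  # assumed import location
from torchao.quantization import Int4WeightOnlyConfig, Int8WeightOnlyConfig


def string_to_config(quantization: str, **kwargs) -> AOBaseConfig:
    """Get quantization config based on quantization string.

    Illustrative dispatch only; the recipe strings and config classes below
    are assumptions, not the mapping shipped in utils.py.
    """
    if quantization == "int8wo":
        return Int8WeightOnlyConfig()
    if quantization == "int4wo":
        return Int4WeightOnlyConfig()
    # Baseline or unrecognized string: return None (despite the AOBaseConfig
    # annotation taken from the diff) so the caller's
    # `if quantization_config is not None` check skips quantize_().
    # kwargs such as high_precision_dtype would be forwarded to the recipes
    # that need them.
    return None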
