77import typing
88from dataclasses import dataclass
99
10- import habana_frameworks as htcore
1110import torch
12- from habana_quantization_toolkit ._core .common import mod_default_dict
13- from habana_quantization_toolkit ._quant_common .quant_config import Fp8cfg , QuantMode , ScaleMethod
11+ from neural_compressor . torch . algorithms . fp8_quant ._core .common import mod_default_dict
12+ from neural_compressor . torch . algorithms . fp8_quant ._quant_common .quant_config import Fp8cfg , QuantMode , ScaleMethod
1413
1514
1615@dataclass
@@ -60,8 +59,6 @@ def run_accuracy_test(
6059 This test also makes asserts the quantization actually happened.
6160 This may be moved to another tests in the future.
6261
63- You can use the generate_test_vectors.py script to generate input test vectors.
64-
6562 Args:
6663 module_class: The reference module class to test.
6764 This should be the direct module to test, e.g. Matmul, Linear, etc.
@@ -82,7 +79,7 @@ def run_accuracy_test(
8279 measure_vectors , test_vectors = itertools .tee (test_vectors )
8380
8481 for mode in [QuantMode .MEASURE , QuantMode .QUANTIZE ]:
85- import habana_quantization_toolkit . prepare_quant .prepare_model as hqt
82+ import neural_compressor . torch . algorithms . fp8_quant . prepare_quant .prepare_model as prepare_model
8683
8784 reference_model = WrapModel (module_class , seed , * module_args , ** module_kwargs )
8885 quantized_model = WrapModel (module_class , seed , * module_args , ** module_kwargs )
@@ -92,7 +89,7 @@ def run_accuracy_test(
9289 lp_dtype = lp_dtype ,
9390 scale_method = scale_method ,
9491 )
95- hqt ._prep_model_with_predefined_config (quantized_model , config = config )
92+ prepare_model ._prep_model_with_predefined_config (quantized_model , config = config )
9693
9794 _assert_quantized_correctly (reference_model = reference_model , quantized_model = quantized_model )
9895
@@ -120,7 +117,7 @@ def run_accuracy_test(
120117 f"\n { scale_method .name = } "
121118 )
122119
123- hqt .finish_measurements (quantized_model )
120+ prepare_model .finish_measurements (quantized_model )
124121
125122
126123def _set_optional_seed (* , module_class : typing .Type [M ], seed : typing .Optional [int ]):
0 commit comments