Commit 2ee53a9

Fix RTN supported layer checking condition [3.x] (#1706)
Signed-off-by: Kaihui-intel <[email protected]>
1 parent 0791776 commit 2ee53a9

File tree

1 file changed: +2 −2 lines
  • neural_compressor/torch/algorithms/weight_only/rtn.py


neural_compressor/torch/algorithms/weight_only/rtn.py

Lines changed: 2 additions & 2 deletions
@@ -81,7 +81,7 @@ def rtn_quantize(
     model.to(device)
 
     assert isinstance(model, torch.nn.Module), "only support torch module"
-    supported_layers = ["Linear"]
+    supported_layers = (torch.nn.Linear,)
     # initialize global configuration
     double_quant_config = {
         "double_quant": kwargs.get("use_double_quant", False),
@@ -93,7 +93,7 @@ def rtn_quantize(
     if export_compressed_model:
         use_optimum_format = kwargs.get("use_optimum_format", True)
     for name, m in model.named_modules():
-        if m.__class__.__name__ not in supported_layers:
+        if not isinstance(m, supported_layers):
             continue
         if name in weight_config:  # pragma: no cover
             # initialize op configuration
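
For context, a minimal sketch of why the isinstance-based check is more robust than comparing m.__class__.__name__ against the string "Linear": the isinstance check also matches subclasses of torch.nn.Linear, while the name-based check only matches modules whose class is literally named Linear. This sketch is not part of the commit, and the MyLinear subclass below is hypothetical.

import torch


class MyLinear(torch.nn.Linear):
    """Hypothetical subclass of torch.nn.Linear (e.g. a patched or wrapped linear layer)."""
    pass


supported_layers = (torch.nn.Linear,)
m = MyLinear(4, 2)

# Old check: compares the class name string, so the subclass is skipped.
print(m.__class__.__name__ in ["Linear"])  # False
# New check: accepts torch.nn.Linear and any of its subclasses.
print(isinstance(m, supported_layers))     # True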
