@@ -23,6 +23,7 @@
     preprocess_scale,
 )
 from torchao.quantization.granularity import PerRow, PerTensor
+from torchao.quantization.utils import get_block_size
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,
     _dequantize_affine_float8,
@@ -33,7 +34,6 @@
     QuantizeTensorKwargs,
     _choose_quant_func_and_quantize_tensor,
 )
-from torchao.quantization.utils import get_block_size
 from torchao.utils import (
     TorchAOBaseTensor,
     _is_fbgemm_genai_gpu_available,
@@ -617,28 +617,6 @@ def _(func, types, args, kwargs): |
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 
-@implements(aten.select.int)
-def _(func, types, args, kwargs):
-    old_float8_tensor, dim, index = args
-    assert dim == 0, f"Float8Tensor aten.select.int with {dim=} is not yet supported"
-    assert len(old_float8_tensor.qdata.shape) == len(old_float8_tensor.scale.shape), (
-        "unsupported"
-    )
-    assert len(old_float8_tensor.qdata.shape) == len(old_float8_tensor.block_size), (
-        "unsupported"
-    )
-    new_float8_tensor = old_float8_tensor.__class__(
-        old_float8_tensor.qdata[index],
-        old_float8_tensor.scale[index],
-        old_float8_tensor.block_size[1:],
-        old_float8_tensor.mm_config,
-        old_float8_tensor.act_quant_kwargs,
-        old_float8_tensor.kernel_preference,
-        old_float8_tensor.dtype,
-    )
-    return return_and_correct_aliasing(func, args, kwargs, new_float8_tensor)
-
-
 Float8Tensor.__module__ = "torchao.quantization"
 
 # Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
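For context on the removed override: when a tensor is indexed with an integer, as in `w[0]`, PyTorch lowers the call to the `aten.select.int` op at the dispatch level, which is what the deleted `@implements(aten.select.int)` handler intercepted for Float8Tensor. Below is a minimal, self-contained sketch (not part of this commit; the class name `LogDispatchedOps` is illustrative) that uses plain PyTorch's `TorchDispatchMode` to show this lowering:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

# Illustrative only: log every aten op that gets dispatched, to show that
# integer indexing such as `w[0]` lowers to aten.select.int -- the op whose
# Float8Tensor override is removed in the diff above.
class LogDispatchedOps(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # prints the dispatched op, e.g. the select.int overload
        return func(*args, **(kwargs or {}))

w = torch.randn(4, 8)
with LogDispatchedOps():
    row = w[0]  # dispatches aten.select.int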