| 
4 | 4 | # This source code is licensed under the license found in the  | 
5 | 5 | # LICENSE file in the root directory of this source tree.  | 
6 | 6 | 
 
  | 
 | 7 | +import logging  | 
7 | 8 | from typing import Optional, Tuple  | 
8 | 9 | 
 
  | 
9 | 10 | import numpy as np  | 
 | 
35 | 36 |     F32_EXP_BIAS,  | 
36 | 37 | )  | 
37 | 38 | 
 
  | 
 | 39 | +logger = logging.getLogger(__name__)  | 
 | 40 | + | 
38 | 41 | 
 
  | 
39 | 42 | def get_bits(x: torch.Tensor) -> str:  | 
40 | 43 |     bits_per_byte = 8  | 
@@ -1476,10 +1479,20 @@ def triton_quantize_nvfp4(  | 
1476 | 1479 |         raise AssertionError("needs torch version 2.8+ and triton")  | 
1477 | 1480 | 
 
  | 
1478 | 1481 | 
 
  | 
# Whether the optional MXFP8 CUDA extension could be imported. Consumers check
# this flag before registering the CUDA-backed custom ops; when it is False the
# pure-fallback path (which raises NotImplementedError) is used instead.
mxfp8_cuda_extension_available = False
if is_sm_at_least_100():
    try:
        # MXFP8 CUDA kernel is only built on SM100+. Furthermore,
        # currently our CI runners are not SM100+, so the user needs to build
        # from source.
        # TODO(#2932): improve this
        from torchao.prototype import mxfp8_cuda

        mxfp8_cuda_extension_available = True
    except ImportError:
        # Use the module-level `logger` (logging.getLogger(__name__)) rather
        # than the root logger via `logging.debug`, so this message respects
        # per-module logging configuration.
        logger.debug("Skipping import of torchao.prototype.mxfp8_cuda")
 | 1494 | + | 
 | 1495 | +if mxfp8_cuda_extension_available:  | 
1483 | 1496 |     # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.  | 
1484 | 1497 |     # Currently we have to use an arbitrary string because custom ops don't support enum  | 
1485 | 1498 |     # params.  | 
@@ -1599,4 +1612,6 @@ def mxfp8_quantize_cuda(  | 
1599 | 1612 |         colwise: bool = True,  | 
1600 | 1613 |         scaling_mode: str = "floor",  | 
1601 | 1614 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:  | 
1602 |  | -        raise NotImplementedError("needs torch version 2.8+ and sm100")  | 
 | 1615 | +        raise NotImplementedError(  | 
 | 1616 | +            "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details."  | 
 | 1617 | +        )  | 
0 commit comments