27 | 27 | from torch.fx import Node |
28 | 28 |
|
29 | 29 | import torchao |
| 30 | +from torchao.quantization import Granularity |
30 | 31 | from torchao.quantization.pt2e.utils import ( |
31 | 32 | calculate_qmin_qmax, |
32 | 33 | check_min_max_valid, |
|
67 | 68 | "ReuseInputObserver", |
68 | 69 | "UniformQuantizationObserverBase", |
69 | 70 | "AffineQuantizedObserverBase", |
70 | | - "Granularity", |
71 | 71 | "MappingType", |
72 | | - "PerAxis", |
73 | | - "PerBlock", |
74 | | - "PerGroup", |
75 | | - "PerRow", |
76 | | - "PerTensor", |
77 | | - "PerToken", |
78 | 72 | "TorchAODType", |
79 | 73 | "ZeroPointDomain", |
80 | | - "get_block_size", |
81 | 74 | ] |
82 | 75 |
|
83 | 76 |
|
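With this change, `Granularity` is imported from `torchao.quantization`, and the concrete granularity classes plus `get_block_size` are no longer defined in or re-exported from this observer module. A minimal sketch of the import pattern callers would use after the change; only `Granularity` is confirmed by this diff, and the subclass names (`PerAxis`, `PerTensor`) are assumed to remain exported by `torchao.quantization`:

```python
# Before: granularity types were re-exported by the pt2e observer module.
# After: import them from torchao.quantization directly (assumed export path).
from torchao.quantization import Granularity, PerAxis, PerTensor

g = PerAxis(axis=0)              # per-row qparams for a 2D weight
assert isinstance(g, Granularity)
assert isinstance(PerTensor(), Granularity)
```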
@@ -1622,7 +1615,6 @@ def calculate_qparams(self): |
1622 | 1615 | We plan to merge the following with torchao repo after we move pt2e flow to torchao |
1623 | 1616 | copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py |
1624 | 1617 | """ |
1625 | | -from dataclasses import dataclass |
1626 | 1618 | from enum import Enum, auto |
1627 | 1619 |
|
1628 | 1620 |
|
@@ -1679,139 +1671,6 @@ class TorchAODType(Enum): |
1679 | 1671 | INT7 = auto() |
1680 | 1672 |
|
1681 | 1673 |
|
1682 | | -@dataclass(frozen=True) |
1683 | | -class Granularity: |
1684 | | - """ |
1685 | | - Base class for representing the granularity of quantization. |
1686 | | -
|
1687 | | - This class serves as a parent for specific granularity types used in |
1688 | | - quantization operations, such as per-tensor or per-axis quantization. |
1689 | | - """ |
1690 | | - |
1691 | | - |
1692 | | -@dataclass(frozen=True) |
1693 | | -class PerBlock(Granularity): |
1694 | | - """ |
1695 | | - Represents per-block granularity in quantization. See |
1696 | | - :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for |
1697 | | - `block_size` |
1698 | | -
|
1699 | | - Attributes: |
1700 | | - block_size (Tuple[int, ...]): The size of each quantization group |
1701 | | - """ |
1702 | | - |
1703 | | - block_size: tuple[int, ...] |
1704 | | - |
1705 | | - |
1706 | | -@dataclass(frozen=True) |
1707 | | -class PerTensor(Granularity): |
1708 | | - """ |
1709 | | - Represents per-tensor granularity in quantization. |
1710 | | -
|
1711 | | - This granularity type calculates the quantization parameters |
1712 | | - based off the entire tensor. |
1713 | | -
|
1714 | | - """ |
1715 | | - |
1716 | | - |
1717 | | -@dataclass(frozen=True) |
1718 | | -class PerAxis(Granularity): |
1719 | | - """ |
1720 | | - Represents per-axis granularity in quantization. |
1721 | | -
|
1722 | | - This granularity type calculates different quantization parameters |
1723 | | - along a specified axis of the tensor. |
1724 | | -
|
1725 | | - For example if the input tensor is shape [8, 16] and axis=0, then |
1726 | | - the quantization parameters are calculated for each row of the tensor. |
1727 | | - Giving a total of 8 quantization parameters. |
1728 | | -
|
1729 | | - Attributes: |
1730 | | - axis (int): The axis along which reduction is performed. |
1731 | | - """ |
1732 | | - |
1733 | | - axis: int |
1734 | | - |
1735 | | - |
1736 | | -@dataclass(frozen=True) |
1737 | | -class PerGroup(Granularity): |
1738 | | - """ |
1739 | | - Represents per-channel group granularity in quantization. |
1740 | | -
|
1741 | | - This granularity type calculates different quantization parameters |
1742 | | - for each group of <group_size> elements. |
1743 | | -
|
1744 | | - For example if the input tensor is shape [8, 16], and the group size is 4, then |
1745 | | - the input tensor is reshaped to [64, 4] |
1746 | | - quantization parameters are calculated for each group of 4 elements, |
1747 | | - giving a total of 64 quantization parameters. |
1748 | | -
|
1749 | | - Attributes: |
1750 | | - group_size (int): The size of each quantization group |
1751 | | -
|
1752 | | - """ |
1753 | | - |
1754 | | - group_size: int |
1755 | | - |
1756 | | - |
1757 | | -class PerRow(Granularity): |
1758 | | - """ |
1759 | | - Represents row-wise granularity in quantization. |
1760 | | -
|
1761 | | - This is a special case of per-axis quantization and is unique to Float8 matmuls |
1762 | | - where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight |
1763 | | - is quantized with a block_size of (1, weight.shape[1]). |
1764 | | - """ |
1765 | | - |
1766 | | - |
1767 | | -class PerToken(Granularity): |
1768 | | - """ |
1769 | | - Represents per-token granularity in quantization. |
1770 | | -
|
1771 | | - This granularity type calculates a different set of quantization parameters |
1772 | | - for each token, which is represented as the last dimension of the tensor. |
1773 | | -
|
1774 | | - For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens |
1775 | | - with 4 elements each, and we will calculate 6 sets of quantization parameters, |
1776 | | - one for each token. |
1777 | | -
|
1778 | | - If the input tensor has only two dimensions, e.g. [8, 16], then this is |
1779 | | - equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters. |
1780 | | - """ |
1781 | | - |
1782 | | - |
1783 | | -def get_block_size( |
1784 | | - input_shape: tuple[int, ...], granularity: Granularity |
1785 | | -) -> tuple[int, ...]: |
1786 | | - """Get the block size based on the input shape and granularity type. |
1787 | | -
|
1788 | | - Args: |
1789 | | - input_shape: The input tensor shape possibly more than 2 dimensions |
1790 | | - granularity: The granularity type of the quantization |
1791 | | - """ |
1792 | | - assert isinstance(granularity, Granularity), ( |
1793 | | - "Please provide an instance of Granularity, not subclass of it" |
1794 | | - ) |
1795 | | - if isinstance(granularity, PerTensor): |
1796 | | - return input_shape |
1797 | | - elif isinstance(granularity, PerAxis): |
1798 | | - block_size = list(input_shape) |
1799 | | - block_size[granularity.axis] = 1 |
1800 | | - return tuple(block_size) |
1801 | | - elif isinstance(granularity, PerRow): |
1802 | | - return (1,) * (len(input_shape) - 1) + (input_shape[-1],) |
1803 | | - elif isinstance(granularity, PerGroup): |
1804 | | - assert len(input_shape) == 2, ( |
1805 | | - f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" |
1806 | | - ) |
1807 | | - return (1, granularity.group_size) |
1808 | | - elif isinstance(granularity, PerToken): |
1809 | | - block_size = [1] * len(input_shape) |
1810 | | - block_size[-1] = input_shape[-1] |
1811 | | - return tuple(block_size) |
1812 | | - raise ValueError(f"Unsupported Granularity: {granularity}") |
1813 | | - |
1814 | | - |
1815 | 1674 | class AffineQuantizedObserverBase(ABC, torch.nn.Module): |
1816 | 1675 | """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) |
1817 | 1676 |
|
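For context, the deleted docstrings describe how each granularity maps to the `block_size` tuple consumed by the affine quantization primitives. The sketch below re-implements that mapping locally for illustration only, rather than assuming where torchao now exposes `get_block_size`; the class imports assume `torchao.quantization` exports them, and only the cases exercised below are covered:

```python
from torchao.quantization import PerAxis, PerGroup, PerTensor, PerToken

# Illustrative re-implementation mirroring the removed helper;
# the upstream torchao version is the source of truth.
def block_size_for(input_shape, granularity):
    if isinstance(granularity, PerTensor):
        return input_shape                                 # one set of qparams for the whole tensor
    if isinstance(granularity, PerAxis):
        size = list(input_shape)
        size[granularity.axis] = 1                         # reduce along the chosen axis
        return tuple(size)
    if isinstance(granularity, PerGroup):
        assert len(input_shape) == 2, "per-group expects a 2D input"
        return (1, granularity.group_size)                 # groups of `group_size` elements
    if isinstance(granularity, PerToken):
        return (1,) * (len(input_shape) - 1) + (input_shape[-1],)  # one set per token
    raise ValueError(f"Unsupported granularity: {granularity}")

print(block_size_for((8, 16), PerTensor()))      # (8, 16)   -> 1 set of qparams
print(block_size_for((8, 16), PerAxis(axis=0)))  # (1, 16)   -> 8 sets, one per row
print(block_size_for((4, 8), PerGroup(4)))       # (1, 4)    -> 8 groups of 4 elements
print(block_size_for((2, 3, 4), PerToken()))     # (1, 1, 4) -> 6 sets, one per token
```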