# LICENSE file in the root directory of this source tree.
import logging
from functools import partial
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Union

import torch.nn as nn

@@ -117,27 +117,28 @@ def convert_to_float8_training(


def _auto_filter_for_recipe(
-    recipe: Float8LinearRecipeName, filter_fqns: List[str]
+    recipe: Union[str, Float8LinearRecipeName], filter_fqns: List[str]
) -> Callable[[nn.Module, str], bool]:
-    """Automatically filters nn.Linear modules that meet at least one of the following criteria:
122+ """Returns function which automatically filters nn.Linear modules that meet at least one of the following criteria:

    1. Dims not divisible by 16 (hardware requirement for float8).
-    2. Dim sizes below certain thresholds, which will result in worse performance.
+    2. Dim sizes below certain thresholds, which may result in worse performance.

    NOTE: the thresholds are simple heuristics based on performance testing, and may not be optimal
    for your model. For the best performance, we recommend defining your own module_filter_fn customized for
    your module, using the performance tables for the given float8 recipe here:
-    https://github.com/pytorch/ao/tree/main/torchao/float8#performance). Note that the benchmarks referenced
-    for auto filtering layers were run on H100 GPUs, and may not be representative of other hardware.
+    https://github.com/pytorch/ao/tree/main/torchao/float8#performance). The benchmarks referenced for
+    auto-filtering layers were run on H100 GPUs, and may not be representative of other hardware.

-
-    The design of this function may change in the future.
+    This is an experimental API; the design may change in the future.
135134 """
136- if recipe == Float8LinearRecipeName .TENSORWISE .value :
135+ if isinstance (recipe , str ):
136+ recipe = Float8LinearRecipeName (recipe )
137+ if recipe == Float8LinearRecipeName .TENSORWISE :
137138 return partial (_auto_filter_for_tensorwise , filter_fqns = filter_fqns )
138- elif recipe == Float8LinearRecipeName .ROWWISE . value :
139+ elif recipe == Float8LinearRecipeName .ROWWISE :
139140 return partial (_auto_filter_for_rowwise , filter_fqns = filter_fqns )
140- elif recipe == Float8LinearRecipeName .ROWWISE_WITH_GW_HP . value :
141+ elif recipe == Float8LinearRecipeName .ROWWISE_WITH_GW_HP :
141142 raise NotImplementedError (f"Unsupported recipe: { recipe } " )
142143 else :
143144 raise ValueError (f"Invalid recipe: { recipe } " )
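For context, a minimal sketch of how the string-accepting API could be exercised. The `convert_to_float8_training` and `Float8LinearConfig` names come from `torchao.float8`; the toy model, the `"tensorwise"` recipe string, and the `"lm_head"` FQN are illustrative assumptions, not part of this diff.

```python
# Hedged usage sketch, assuming torchao.float8's public entry points and that
# _auto_filter_for_recipe is importable from this module.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(
    nn.Linear(4096, 4096),  # dims divisible by 16: eligible for float8
    nn.Linear(4096, 13),    # 13 % 16 != 0: auto-filtered out
)

# After this change, a plain string works in place of the enum member.
filter_fn = _auto_filter_for_recipe("tensorwise", filter_fqns=["lm_head"])

convert_to_float8_training(
    model,
    config=Float8LinearConfig(),  # tensorwise scaling defaults
    module_filter_fn=filter_fn,   # (module, fqn) -> bool
)
```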
@@ -153,7 +154,7 @@ def _auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns: List[str]) -> bool:
        return False

    # All dims must be divisible by 16 due to float8 hardware requirements.
-    K, N = mod.weight.shape
+    N, K = mod.weight.shape
    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
    if not dims_multiples_of_16:
        return False
@@ -183,7 +184,7 @@ def _auto_filter_for_tensorwise(
        return False

    # All dims must be divisible by 16 due to float8 hardware requirements.
-    K, N = mod.weight.shape
+    N, K = mod.weight.shape
    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
    if not dims_multiples_of_16:
        return False
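The `K, N` to `N, K` swap in this hunk and the rowwise one above is the substantive fix: `nn.Linear` stores its weight as `(out_features, in_features)`, i.e. `(N, K)`. The `% 16` check itself is symmetric, so the correction presumably matters for the size-threshold heuristics further down in these filters (not shown in this diff). A quick sanity check:

```python
import torch.nn as nn

# nn.Linear(in_features=K, out_features=N) keeps its weight as (N, K),
# so the old `K, N = mod.weight.shape` had the dims reversed.
lin = nn.Linear(in_features=32, out_features=64)  # K=32, N=64
assert lin.weight.shape == (64, 32)  # (N, K), not (K, N)
```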