@@ -116,7 +116,7 @@ def convert_to_float8_training(
116116 )
117117
118118
119- def auto_filter_for_recipe (
119+ def _auto_filter_for_recipe (
120120 recipe : Float8LinearRecipeName , filter_fqns : List [str ]
121121) -> Callable [[nn .Module , str ], bool ]:
122122 """Automatically filters nn.Linear modules that meet at least one of the following criteria:
@@ -127,7 +127,9 @@ def auto_filter_for_recipe(
127127 NOTE: the thresholds are simple heuristics based on performance testing, and may not be optimal
128128 for your model. For the best performance, we recommend defining your own module_filter_fn customized for
129129 your module, using the performance tables for the given float8 recipe here:
130- https://github.com/pytorch/ao/tree/main/torchao/float8#performance).
130+ https://github.com/pytorch/ao/tree/main/torchao/float8#performance). Note that the benchmarks referenced
131+ for auto filtering layers were run on H100 GPUs, and may not be representative of other hardware.
132+
131133
132134 The design of this function may change in the future.
133135 """
@@ -156,8 +158,10 @@ def _auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns: List[str]) -
156158 if not dims_multiples_of_16 :
157159 return False
158160
159- # Dims below these thresholds will result in worse performance
161+ # Dims below these thresholds may result in worse performance
160162 # (see https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling)
163+ # Note that these benchmarks referenced for auto filtering layers were run on
164+ # H100 GPUs, and may not be representative of other hardware.
161165 if N <= 2048 :
162166 return False
163167 elif K <= 1024 :
@@ -184,8 +188,10 @@ def _auto_filter_for_tensorwise(
184188 if not dims_multiples_of_16 :
185189 return False
186190
187- # Dims below these thresholds will result in worse performance
191+ # Dims below these thresholds may result in worse performance
188192 # (see https://github.com/pytorch/ao/tree/main/torchao/float8#tensorwise-scaling)
193+ # Note that these benchmarks referenced for auto filtering layers were run on
194+ # H100 GPUs, and may not be representative of other hardware.
189195 if K <= 4096 and N <= 1024 :
190196 return False
191197 return True
0 commit comments