 import copy
 import logging
 from enum import auto, Enum
-from typing import Callable, List, Optional, Type
+from typing import Callable, List, Optional, Type, Union

 import torch
 import torch.distributed as dist
@@ -97,45 +97,51 @@ def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear],
     )


-def swap_linear_with_float8_linear(
+def swap_linear_layers(
     module: nn.Module,
-    module_cls: Type[nn.Module],
+    from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
     skip_fqn_list: Optional[List[str]] = None,
-    emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
-) -> nn.Module:
+) -> Optional[nn.Module]:
     """
-    Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
-    of ``module_cls`` (either ``Float8Linear`` or ``Float8DynamicLinear``).
+    Generic function to swap linear layers in a module with a new type of linear layer.
+
+    Note:
+        If applied to a root-level nn.Linear, the module will not be modified in place;
+        the swapped module will be returned instead.

     Args:
-        module (torch.nn.Module): Module to modify.
-        module_cls (Union[Type[Float8Linear], Type[Float8DynamicLinear]]): Float8 linear class for the swap.
-        skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
-            Linear submodules of these skipped modules will also be skipped.
-        emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
-        linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
+        module: Module to modify.
+        from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
+        skip_fqn_list: If specified, a list of module FQNs to skip.
+        linear_layer_filter: If specified, only the linear layers
             that pass the filter function will be swapped.
+        from_float_kwargs: Additional keyword arguments for from_float_func.
+
+    Returns:
+        nn.Module: The modified module with swapped linear layers.
     """
     module_names_to_skip = set(skip_fqn_list or [])
+
     if isinstance(module, nn.Linear) and (
         linear_layer_filter is None or linear_layer_filter(module)
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
                 f"Does not support a root nn.Linear with children: {module}"
             )
-        return module_cls.from_float(module, emulate=emulate)
+        return from_float_func(
+            module,
+        )

-    # Mark all modules to skip as visited
     root_module = module
     visited_modules = {root_module}
+
     for module_name, module in root_module.named_modules():
         if module_name in module_names_to_skip:
             visited_modules.add(module)

-    # Run a post-order traversal to swap linears
     def post_order_traversal(
         module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
     ):
@@ -144,14 +150,15 @@ def post_order_traversal(
             if child_module not in visited_modules:
                 visited_modules.add(child_module)
                 post_order_traversal(child_module, child_module_name, module)
+
         if isinstance(module, nn.Linear) and (
             linear_layer_filter is None or linear_layer_filter(module)
         ):
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
-            float8linear_module = module_cls.from_float(module, emulate=emulate)
-            setattr(parent_module, module_name, float8linear_module)
+            new_linear_module = from_float_func(module)
+            setattr(parent_module, module_name, new_linear_module)

     post_order_traversal(root_module, "", None)
     # Without this explicit `del`, this set only gets deleted upon an explicit
@@ -160,6 +167,22 @@ def post_order_traversal(
     return root_module


+def swap_linear_with_float8_linear(
+    module: nn.Module,
+    module_cls: Union[Type[Float8Linear], Type[Float8DynamicLinear]],
+    *,
+    skip_fqn_list: Optional[List[str]] = None,
+    emulate: bool = False,
+    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+) -> Optional[nn.Module]:
+    return swap_linear_layers(
+        module,
+        lambda m: module_cls.from_float(m, emulate=emulate),
+        skip_fqn_list=skip_fqn_list,
+        linear_layer_filter=linear_layer_filter,
+    )
+
+
 def get_float8_layers(model: torch.nn.Module):
     """Iterates through the model and returns all the Float8Linear layers.
     Args:
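For reference, a minimal usage sketch of the refactored API. It assumes Float8DynamicLinear still exposes the from_float(mod, emulate=...) classmethod that the wrapper above relies on; the import paths and the toy model are illustrative assumptions, not part of this diff.

    import torch.nn as nn

    # Assumed module paths based on the repo layout.
    from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
    from float8_experimental.float8_linear_utils import swap_linear_layers

    # A toy model whose nn.Linear children should be swapped.
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))

    # The generic entry point takes any callable mapping nn.Linear to a new
    # linear type; this lambda reproduces what swap_linear_with_float8_linear
    # does internally.
    model = swap_linear_layers(
        model,
        lambda m: Float8DynamicLinear.from_float(m, emulate=True),
    )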