@@ -72,12 +72,6 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)
 
-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
-        return swap_linear_with_float8_linear(module, **kwargs)
-
 
 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
     @property
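For context on the hunks that follow: the deleted swap_linear_with_dynamic helper existed only to pin every scaling type to DYNAMIC, and the remaining hunks call swap_linear_with_float8_linear directly, which is consistent with dynamic scaling being the default. A minimal before/after sketch of the call pattern; the import paths and the toy modules are assumptions, not taken from this diff:

# Sketch only; import locations are assumed, not shown in this change.
import copy
import torch.nn as nn
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

old_style = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
new_style = copy.deepcopy(old_style)

# Old pattern: the removed helper forced dynamic scaling for x, w, and dL_dY.
swap_linear_with_float8_linear(
    old_style,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)

# New pattern used in the hunks below: swap in place and rely on the defaults.
swap_linear_with_float8_linear(new_style)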
@@ -106,11 +100,11 @@ def _test_transformer_parity_dynamic(
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
-        module = self.init_transformer(weight_tying=weight_tying)
+        module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        swap_linear_with_float8_linear(ref_module)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            swap_linear_with_float8_linear(module)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
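This transformer-parity hunk shows the flow the rest of the diff repeats: swap the unsharded reference copy, swap the test module under the fp8 all-gather toggle, then apply FSDP2's fully_shard per TransformerBlock and once at the root (the root-level call appears in the later hunks). A hedged sketch of that flow; the imports, the helper name, and the block_cls parameter are assumptions rather than excerpts from this diff:

# Hedged sketch of the test flow; import locations and the block class are assumed.
import copy
from torch.distributed._composable.fsdp import fully_shard  # assumed FSDP2 entry point
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from test_fsdp2_common import set_enable_fsdp_fp8_all_gather  # assumed test-utility location

def build_fp8_fsdp_pair(module, block_cls, enable_fsdp_fp8_all_gather: bool):
    """Return (sharded fp8 module, unsharded fp8 reference), mirroring the parity tests."""
    ref_module = copy.deepcopy(module)
    swap_linear_with_float8_linear(ref_module)  # the reference also uses fp8 compute
    # Only the swap sits under the toggle: it controls whether weights are
    # pre-cast to fp8 for FSDP's all-gather when the modules are constructed.
    with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
        swap_linear_with_float8_linear(module)
    for submodule in module.modules():
        if isinstance(submodule, block_cls):  # e.g. TransformerBlock in these tests
            fully_shard(submodule)
    fully_shard(module)  # root-level sharding, as in the later hunks
    return module, ref_module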
@@ -153,7 +147,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            swap_linear_with_float8_linear(model, emulate=True)
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -331,7 +325,8 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module_fp32 = swap_linear_with_float8_linear(module_fp32)
+            module = module_fp32
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -359,7 +354,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module = self.init_multi_module()
         ref_module = copy.deepcopy(module)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module)
+            module = swap_linear_with_float8_linear(module)
         for submodule in module:
             fully_shard(submodule)
         fully_shard(module)
@@ -383,10 +378,11 @@ def test_fp32_fp8_single_module_parity(self):
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = copy.deepcopy(module_fp32)
+            ref_module = swap_linear_with_float8_linear(ref_module)
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = swap_linear_with_float8_linear(module_fp32)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -407,11 +403,11 @@ def test_fp32_fp8_multi_module_parity(self):
         multiple modules/FSDP communication groups.
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
-            module = self.init_multi_module()
+            module = self.init_multi_module().cuda()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = swap_linear_with_float8_linear(ref_module)
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = swap_linear_with_float8_linear(module)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)