@@ -99,13 +99,13 @@ def short_str(self):
 
     def __post_init__(self):
         if self.scaling_type is ScalingType.STATIC:
-            assert (
-                self.static_scale is not None
-            ), "static_scale must be specified for static scaling"
+            assert self.static_scale is not None, (
+                "static_scale must be specified for static scaling"
+            )
         if self.scaling_granularity is ScalingGranularity.AXISWISE:
-            assert (
-                self.scaling_type is ScalingType.DYNAMIC
-            ), "only dynamic scaling type is supported for axiswise scaling granularity"
+            assert self.scaling_type is ScalingType.DYNAMIC, (
+                "only dynamic scaling type is supported for axiswise scaling granularity"
+            )
         assert self.target_dtype is None or (
             self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
         ), "must specify a 8-bit floating-point dtype"
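
The change in this hunk, and in the hunks below, is purely mechanical: the condition moves back onto the assert line and only the error message is wrapped in parentheses, which is the shape recent auto-formatters generally produce for long assert messages. Both layouts are semantically identical; a minimal standalone sketch (using a stand-in static_scale variable rather than the real config object) shows the before/after:

    # Stand-in variable for illustration only; not the real config field.
    static_scale = 1.0

    # Old layout: the condition is parenthesized and wrapped across lines.
    assert (
        static_scale is not None
    ), "static_scale must be specified for static scaling"

    # New layout: the condition stays on the assert line; only the message wraps.
    assert static_scale is not None, (
        "static_scale must be specified for static scaling"
    )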
@@ -130,9 +130,9 @@ class DelayedScalingConfig:
     scale_fn_name: str = "max"
 
     def __post_init__(self):
-        assert (
-            self.scale_fn_name == "max"
-        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+        assert self.scale_fn_name == "max", (
+            f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+        )
 
 
 @dataclass(frozen=True)
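
Since the assert runs inside __post_init__, an unsupported scale_fn_name is rejected the moment a config object is constructed. A toy stand-in (not the real DelayedScalingConfig, which has more fields than this diff shows) illustrates when the check fires:

    from dataclasses import dataclass

    # Toy stand-in for DelayedScalingConfig, reduced to the one field shown above.
    @dataclass
    class _ToyDelayedScalingConfig:
        scale_fn_name: str = "max"

        def __post_init__(self):
            assert self.scale_fn_name == "max", (
                f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
            )

    _ToyDelayedScalingConfig()          # ok: default "max"
    # _ToyDelayedScalingConfig("mean")  # would raise AssertionError at construction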
@@ -148,7 +148,6 @@ class Float8GemmConfig:
 
 # Pre-made recipes for common configurations
 class Float8LinearRecipeName(enum.Enum):
-
     # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
     TENSORWISE = "tensorwise"
 
@@ -291,7 +290,9 @@ def __post_init__(self):
 
         # float8 all-gather only supports tensorwise, in the future may support blockwise
         if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE:
-            assert not self.enable_fsdp_float8_all_gather, f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            assert not self.enable_fsdp_float8_all_gather, (
+                f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            )
 
         # save some characters in the compatibility checks below
         cc_i = self.cast_config_input
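
This hunk only rewraps the message, but the rule it guards is easy to trip: float8 all-gather under FSDP requires tensorwise weight scaling. A hedged usage sketch of the rejected combination, assuming the enclosing dataclass is Float8LinearConfig and that these names are importable from torchao.float8.config (neither the class name nor the file path is visible in this diff):

    # Assumed import path and class name; see the caveat above.
    from torchao.float8.config import (
        CastConfig,
        Float8LinearConfig,
        ScalingGranularity,
    )

    # Default tensorwise weight scaling: float8 all-gather is allowed.
    ok = Float8LinearConfig(enable_fsdp_float8_all_gather=True)

    # Axiswise weight scaling plus float8 all-gather would trip the assert above:
    # Float8LinearConfig(
    #     cast_config_weight=CastConfig(scaling_granularity=ScalingGranularity.AXISWISE),
    #     enable_fsdp_float8_all_gather=True,
    # )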
@@ -310,9 +311,9 @@ def __post_init__(self):
         ):
             is_disabled_1 = cc1.scaling_type is ScalingType.DISABLED
             is_disabled_2 = cc1.scaling_type is ScalingType.DISABLED
-            assert (
-                is_disabled_1 == is_disabled_2
-            ), f"incompatible operand precision for {gemm_name}"
+            assert is_disabled_1 == is_disabled_2, (
+                f"incompatible operand precision for {gemm_name}"
+            )
 
         for cc1, cc2, operand_name, default_dtype in [
             (cc_i, cc_i_gw, "input", e4m3_dtype),
@@ -324,9 +325,9 @@ def __post_init__(self):
                 object.__setattr__(cc1, "target_dtype", default_dtype)
             if cc2.target_dtype is None:
                 object.__setattr__(cc2, "target_dtype", default_dtype)
-            assert (
-                cc1.target_dtype == cc2.target_dtype
-            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            assert cc1.target_dtype == cc2.target_dtype, (
+                f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            )
 
         # See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
         if (
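
The object.__setattr__ calls in this hunk are the usual way to fill in fields on a frozen dataclass after construction (CastConfig's decorator is not visible in this diff, but plain assignment would suffice otherwise). A small self-contained illustration with a toy class:

    import dataclasses

    @dataclasses.dataclass(frozen=True)
    class _ToyCastConfig:
        target_dtype: object = None

    cfg = _ToyCastConfig()
    try:
        cfg.target_dtype = "e4m3"        # frozen=True blocks normal assignment
    except dataclasses.FrozenInstanceError:
        pass
    # The escape hatch used in the hunk above: bypass the frozen __setattr__.
    object.__setattr__(cfg, "target_dtype", "e4m3")
    assert cfg.target_dtype == "e4m3"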
@@ -357,9 +358,9 @@ def from_recipe_name(
         """
         if type(recipe_name) == str:
             valid_names = [n.value for n in Float8LinearRecipeName]
-            assert (
-                recipe_name in valid_names
-            ), f"recipe_name {recipe_name} not in valid names {valid_names}"
+            assert recipe_name in valid_names, (
+                f"recipe_name {recipe_name} not in valid names {valid_names}"
+            )
             recipe_name = Float8LinearRecipeName(recipe_name)
 
         if recipe_name is Float8LinearRecipeName.TENSORWISE:
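
The string branch above lets callers pass either a recipe-name string or the enum member; an unknown string fails the assert with the list of valid names before the enum conversion. A hedged usage sketch, assuming the enclosing class is Float8LinearConfig and that both names are importable from torchao.float8.config (assumptions, since neither is visible in this diff):

    # Assumed import path and class name; see the caveat above.
    from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName

    # Either spelling resolves to the same recipe.
    cfg_from_str = Float8LinearConfig.from_recipe_name("tensorwise")
    cfg_from_enum = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)

    # An unknown string would fail the assert above:
    # Float8LinearConfig.from_recipe_name("tensor-wise")  # AssertionError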
@@ -385,7 +386,6 @@ def from_recipe_name(
             )
 
         elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
-
             # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
             cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
             cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)