@@ -53,6 +53,35 @@ def short_str(self):
         return "axs"
 
 
+@dataclass
+class Float8TypeConfig:
+    """
+    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
+
+    Currently, ROCm only supports fnuz variants.
+    """
+
+    # The preferred e4m3 type.
+    e4m3_dtype = torch.float8_e4m3fn
+
+    # The preferred e5m2 type.
+    e5m2_dtype = torch.float8_e5m2
+
+    def __post_init__(self):
+        if torch.version.hip and torch.cuda.is_available():
+            prop = torch.cuda.get_device_properties(0)
+            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
+            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
+                self.e4m3_dtype = torch.float8_e4m3fnuz
+                self.e5m2_dtype = torch.float8_e5m2fnuz
+
+
+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class CastConfig:
     """
@@ -62,9 +91,11 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    target_dtype: Optional[torch.dtype] = None
 
     def short_str(self):
-        return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
+        dtype = {e4m3_dtype: "e4m3", e5m2_dtype: "e5m2"}[self.target_dtype]
+        return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}"
 
     def __post_init__(self):
         if self.scaling_type is ScalingType.STATIC:
@@ -75,6 +106,9 @@ def __post_init__(self):
             assert (
                 self.scaling_type is ScalingType.DYNAMIC
             ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        assert self.target_dtype is None or (
+            self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
+        ), "must specify an 8-bit floating-point dtype"
 
 
 @dataclass(frozen=True)
@@ -101,29 +135,6 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
 
 
-@dataclass
-class Float8TypeConfig:
-    """
-    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
-
-    Currently, ROCm only supports fnuz variants.
-    """
-
-    # The preferred e4m3 type.
-    e4m3_dtype = torch.float8_e4m3fn
-
-    # The preferred e5m2 type.
-    e5m2_dtype = torch.float8_e5m2
-
-    def __post_init__(self):
-        if torch.version.hip and torch.cuda.is_available():
-            prop = torch.cuda.get_device_properties(0)
-            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
-            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
-                self.e4m3_dtype = torch.float8_e4m3fnuz
-                self.e5m2_dtype = torch.float8_e5m2fnuz
-
-
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -276,6 +287,20 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"
 
+        for cc1, cc2, operand_name, default_dtype in [
+            (cc_i, cc_i_gw, "input", e4m3_dtype),
+            (cc_w, cc_w_gi, "weight", e4m3_dtype),
+            (cc_go, cc_go_gw, "grad_output", e5m2_dtype),
+        ]:
+            # Override the dataclass being frozen
+            if cc1.target_dtype is None:
+                object.__setattr__(cc1, "target_dtype", default_dtype)
+            if cc2.target_dtype is None:
+                object.__setattr__(cc2, "target_dtype", default_dtype)
+            assert (
+                cc1.target_dtype == cc2.target_dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"
 
@@ -334,18 +359,23 @@ def recipe_name_to_linear_config(
         # * `input`, `weight` and `grad_output` now only need to be scaled
         #   axiswise across a single dim compared to vanilla all-axiswise,
         #   which is more amenable to fast kernels
+        # * the e4m3 dtype is used across the board, including for gradients
 
         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
         cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
 
         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+        )
         cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
 
         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
-        cc_go_gw = CastConfig(
-            scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
-        )
+        cc_go_gw = CastConfig(
+            scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
+        )
 
         return Float8LinearConfig(
             cast_config_input=cc_i,
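
For context, a short hypothetical check of which float8 pair Float8TypeConfig resolves to on the current machine (the torchao.float8.config module path is assumed from the file being patched): on MI300-class ROCm GPUs (gfx940/941/942) it yields the fnuz variants, everywhere else the standard e4m3fn/e5m2 pair.

from torchao.float8.config import Float8TypeConfig  # assumed module path

type_config = Float8TypeConfig()  # __post_init__ probes the GPU arch via torch.cuda
print(type_config.e4m3_dtype, type_config.e5m2_dtype)
# -> torch.float8_e4m3fn / torch.float8_e5m2 on NVIDIA GPUs,
#    torch.float8_e4m3fnuz / torch.float8_e5m2fnuz on MI300 (gfx940/941/942)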
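
And a minimal usage sketch of the new per-operand target_dtype knob, mirroring the recipe change above that keeps grad_output in e4m3; the import path is again an assumption, and e4m3_dtype is the module-level alias introduced by this patch.

from torchao.float8.config import (  # assumed module path
    CastConfig,
    ScalingGranularity,
    e4m3_dtype,  # resolved to the fn or fnuz variant by Float8TypeConfig
)

# Cast grad_output axiswise to e4m3 instead of the per-operand default (e5m2).
cc_go = CastConfig(
    scaling_granularity=ScalingGranularity.AXISWISE,
    target_dtype=e4m3_dtype,
)

Leaving target_dtype=None preserves the old behavior: Float8LinearConfig.__post_init__ fills in e4m3 for the input/weight casts and e5m2 for grad_output, and asserts that both matmuls consuming an operand cast it to the same dtype.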