@@ -50,7 +50,11 @@ def decorator(func):
 def float8_desugar_op(aten_op, args, kwargs=None):
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
     return Float8Tensor(
-        new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+        new_data,
+        args[0]._scale,
+        args[0]._orig_dtype,
+        args[0]._mm_config,
+        args[0]._scaling_strategy,
     )


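For orientation, here is a minimal standalone sketch of the pattern this and the following hunks apply: every handler that rebuilds the wrapper after an aten op must now forward the scaling strategy alongside the scale, original dtype, and mm config. `ToyFloat8`, the `None` mm config, and the string-valued strategy are placeholders for illustration only, not the real `Float8Tensor`, `ScaledMMConfig`, or scaling-strategy types.

from dataclasses import dataclass
import torch

@dataclass
class ToyFloat8:  # stand-in for Float8Tensor's metadata layout, not the real class
    _data: torch.Tensor
    _scale: torch.Tensor
    _orig_dtype: torch.dtype
    _mm_config: object
    _scaling_strategy: str  # placeholder for the real scaling-strategy type

def toy_desugar_op(aten_op, args, kwargs=None):
    # Mirrors float8_desugar_op: run the op on the raw data, then rebuild the
    # wrapper, forwarding every piece of metadata including the scaling strategy.
    new_data = aten_op(args[0]._data, *args[1:], **(kwargs or {}))
    return ToyFloat8(
        new_data,
        args[0]._scale,
        args[0]._orig_dtype,
        args[0]._mm_config,
        args[0]._scaling_strategy,
    )

x = ToyFloat8(torch.zeros(4, 4, dtype=torch.uint8), torch.tensor(1.0),
              torch.float32, None, "TensorWise")
y = toy_desugar_op(torch.ops.aten.view.default, (x, [16]))
assert y._scaling_strategy == x._scaling_strategy  # metadata preserved through the op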
@@ -60,7 +64,11 @@ def float8_split(aten_op, args, kwargs=None):

     def make_float8(data):
         return Float8Tensor(
-            data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+            data,
+            args[0]._scale,
+            args[0]._orig_dtype,
+            args[0]._mm_config,
+            args[0]._scaling_strategy,
         )

     out = map(make_float8, new_data_tensors)
@@ -75,6 +83,7 @@ def float8_cat(aten_op, args, kwargs=None):
     orig_dtype = chunked_tensors[0]._orig_dtype
     scale = chunked_tensors[0]._scale
     mm_config = chunked_tensors[0]._mm_config
+    scaling_strategy = chunked_tensors[0]._scaling_strategy
     fp8_dtype = chunked_tensors[0]._data.dtype
     chunk_data = []
     for chunk in chunked_tensors:
@@ -93,11 +102,14 @@ def float8_cat(aten_op, args, kwargs=None):
         assert (
             chunk._data.dtype == fp8_dtype
         ), "Expecting all chunks to be of the same dtype as a result of a split"
+        assert (
+            chunk._scaling_strategy is scaling_strategy
+        ), "Expecting all chunks to have the same scaling strategy as a result of a split"
         chunk_data.append(chunk._data.view(torch.uint8))

     new_data = aten_op(chunk_data, *args[1:], **kwargs)
     new_data = new_data.view(fp8_dtype)
-    return Float8Tensor(new_data, scale, orig_dtype, mm_config)
+    return Float8Tensor(new_data, scale, orig_dtype, mm_config, scaling_strategy)


 @implements([aten.sum.dim_IntList])
@@ -162,6 +174,11 @@ def float8_mm(aten_op, args, kwargs=None):
         return torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
         )
+    scaling_strategy = a._scaling_strategy
+    # TODO We can enable this by broadcasting to the more generic form
+    assert (
+        scaling_strategy == b._scaling_strategy
+    ), "Scaling strategies are currently required to be the same"
     tensor_out = addmm_float8_unwrapped(
         a_data,
         a_scale,
@@ -191,6 +208,11 @@ def float8_addmm(aten_op, args, kwargs=None):
     a_mm_config: ScaledMMConfig = a._mm_config
     b_mm_config: ScaledMMConfig = b._mm_config
     mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
+    scaling_strategy = a._scaling_strategy
+    # TODO We can enable this by broadcasting to the more generic form
+    assert (
+        scaling_strategy == b._scaling_strategy
+    ), "Scaling strategies are currently required to be the same"
     if mm_config.emulate:
         out = torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
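The float8_mm and float8_addmm hunks above add the same guard: both operands must use the same scaling strategy before the scaled matmul is dispatched (the TODO notes that mixed strategies could later be handled by broadcasting to the more general form). Below is a minimal standalone sketch of that guard, using a plain dequantize-and-matmul as a stand-in for the real scaled matmul and placeholder string strategies; none of these names come from the library itself.

import torch

def toy_scaled_mm(a_data, a_scale, b_data, b_scale,
                  a_strategy, b_strategy, output_dtype):
    # Mirrors the assertion added above: refuse mixed scaling strategies
    # before doing the matmul (here emulated by dequantizing and using @).
    assert a_strategy == b_strategy, (
        "Scaling strategies are currently required to be the same"
    )
    a_hp = a_data.to(output_dtype) / a_scale
    b_hp = b_data.to(output_dtype) / b_scale
    return a_hp @ b_hp

a = torch.randn(4, 8)
b = torch.randn(8, 2)
out = toy_scaled_mm(a, torch.tensor(1.0), b, torch.tensor(1.0),
                    "TensorWise", "TensorWise", torch.float32)
print(out.shape)  # torch.Size([4, 2]); mismatched strategies would raise instead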
@@ -229,7 +251,11 @@ def autocast_to_copy(aten_op, args, kwargs=None):
         torch.bfloat16,
     }, "Only support floating point conversion for autocast w/ Float8Tensor"
     return Float8Tensor(
-        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
+        args[0]._data,
+        args[0]._scale,
+        kwargs["dtype"],
+        args[0]._mm_config,
+        args[0]._scaling_strategy,
     )


@@ -252,7 +278,11 @@ def allgather_fp8(aten_op, args, kwargs=None):
     fp8_data = fp8_data.contiguous()
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
+        fp8_input._scaling_strategy,
     )


@@ -264,7 +294,11 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
     fp8_data = fp8_input._data
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
+        fp8_input._scaling_strategy,
     )


@@ -282,7 +316,11 @@ def index_put_fp8(aten_op, args, kwargs=None):
     fp8_values_data = fp8_values._data
     fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
+        fp8_out,
+        fp8_self._scale,
+        fp8_self._orig_dtype,
+        fp8_self._mm_config,
+        fp8_self._scaling_strategy,
     )


@@ -315,6 +353,12 @@ def copy_fp8(aten_op, args, kwargs=None):
             self._data.dtype == src._data.dtype
         ), "Expecting both Float8Tensors to be of the same dtypet"
         fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
-        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
+        return Float8Tensor(
+            fp8_out,
+            self._scale,
+            self._orig_dtype,
+            self._mm_config,
+            self._scaling_strategy,
+        )
     else:
         raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")