File tree Expand file tree Collapse file tree 2 files changed +16
-2
lines changed 
torchao/prototype/moe_training Expand file tree Collapse file tree 2 files changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -40,7 +40,7 @@ def _scaled_grouped_mm(
4040        offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. 
4141        out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. 
4242    """ 
43-     logger .info ("Using scaled_grouped_mm" )
43+     #  logger.info("Using scaled_grouped_mm")
4444    return  _Float8GroupedMM .apply (
4545        A ,
4646        B_t ,
Original file line number Diff line number Diff line change @@ -47,7 +47,6 @@ def __new__(
4747        cls ,
4848        tensor : torch .Tensor ,
4949    ):
50-         # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}") 
5150        return  torch .Tensor ._make_wrapper_subclass (
5251            cls ,
5352            tensor .size (),
@@ -155,9 +154,24 @@ def fsdp_post_all_gather(
155154    ):
156155        (data ,) =  all_gather_outputs 
157156
157+         # For training step 1+, out=unsharded param, so we need to copy data to `out`
158+         # if `self._data` and `out` do not share the same storage.
159+         # Otherwise, if they do share the same storage, we can just return directly. 
158160        if  out  is  not   None :
161+             assert  isinstance (out , ScaledGroupedMMTensor ), f"{ type (out )}  " 
162+             if  data .dtype  ==  param_dtype :
163+                 assert  (
164+                     data .untyped_storage ().data_ptr ()
165+                     ==  out ._data .untyped_storage ().data_ptr ()
166+                 )
167+             else :
168+                 assert  out ._data .dtype  ==  param_dtype , (
169+                     f"{ out ._data .dtype }   { param_dtype }  " 
170+                 )
171+                 out ._data .copy_ (data )
159172            return 
160173
174+         # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor. 
161175        output  =  ScaledGroupedMMTensor (data )
162176        inner_tensors  =  (data ,)
163177        return  output , inner_tensors 
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments