@@ -48,6 +48,7 @@ def __new__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
+        logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -66,14 +67,13 @@ def __init__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
+        logger.info(f"ScaledGroupedMMTensor __init__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         self._data = tensor.to(dtype)
         self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.debug(
-            f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}"
-        )
+        logger.info(f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
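As an aside, the override above relies on the standard __torch_function__ protocol: every torch-level op invoked on the subclass funnels through this one classmethod, which can reroute a specific op (here the grouped mm) to the differentiable _scaled_grouped_mm and defer to the default implementation for everything else. A minimal runnable sketch of that pattern, independent of torchao (the LoggingTensor name and the logging setup are illustrative, not part of this commit):

import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LoggingTensor(torch.Tensor):
    """Hypothetical subclass: logs every intercepted torch op, then defers."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Every torch-level op on a LoggingTensor lands here first.
        logger.info(f"LoggingTensor func: {func.__name__}")
        # Fall through to the base-class behavior, which runs the op and
        # re-wraps tensor results in the subclass where applicable.
        return super().__torch_function__(func, types, args, kwargs)


x = torch.randn(4, 4).as_subclass(LoggingTensor)
y = x @ x  # logs "__matmul__" before computing the product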
@@ -148,9 +148,7 @@ def fsdp_pre_all_gather(
     ):
         all_gather_inputs = (self._data,)
         all_gather_metadata = ()
-        logger.debug(
-            f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}"
-        )
+        # logger.info(f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, self._data.shape={self._data.shape}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -162,9 +160,7 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        logger.debug(
-            f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}"
-        )
+        # logger.info(f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
 
         if out is not None:
             return
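For context, the two fsdp_* hooks instrumented above are FSDP2's tensor-subclass extension point: fsdp_pre_all_gather hands FSDP the raw inner tensor(s) to communicate plus optional per-parameter metadata, and fsdp_post_all_gather rebuilds the subclass around the gathered data, returning early when FSDP supplies a preallocated out buffer. A hedged sketch of that contract, with simplified signatures and a hypothetical WrapperTensor rather than the torchao implementation:

from typing import Optional

import torch


class WrapperTensor(torch.Tensor):
    """Hypothetical wrapper that keeps its payload in self._data."""

    def fsdp_pre_all_gather(self, mesh):
        # Give FSDP the raw inner tensor(s) to all-gather; the empty tuple
        # is metadata threaded through to the post-all-gather hook.
        return (self._data,), ()

    def fsdp_post_all_gather(
        self,
        all_gather_outputs: tuple,
        metadata,
        param_dtype: torch.dtype,
        *,
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
        if out is not None:
            # FSDP reused a preallocated output tensor; nothing to return.
            return
        # Otherwise return the rebuilt subclass plus the inner tensors
        # that FSDP should free once the unsharded parameter is dropped.
        return data.as_subclass(type(self)), (data,)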