|
18 | 18 | ScaledMMConfig, |
19 | 19 | ) |
20 | 20 |
|
21 | | -from float8_experimental.float8_utils import EPS |
| 21 | +from float8_experimental.float8_utils import e4m3_dtype, EPS |
22 | 22 | from torch._prims_common import suggest_memory_format |
23 | 23 |
|
24 | 24 |
|
@@ -189,3 +189,182 @@ def fsdp_post_all_gather( |
189 | 189 | out._scale = scale |
190 | 190 | return |
191 | 191 | return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,) |
| 192 | + |
| 193 | + |
| 194 | +class WeightWithDelayedFloat8CastTensor(torch.Tensor): |
| 195 | + @staticmethod |
| 196 | + def __new__( |
| 197 | + cls, |
| 198 | + tensor: torch.Tensor, |
| 199 | + amax_buffer: torch.Tensor, |
| 200 | + amax_history_buffer: torch.Tensor, |
| 201 | + scale_buffer: torch.Tensor, |
| 202 | + mm_config: ScaledMMConfig, |
| 203 | + is_amax_initialized: bool, |
| 204 | + ): |
| 205 | + return torch.Tensor._make_wrapper_subclass( |
| 206 | + cls, |
| 207 | + tensor.size(), |
| 208 | + strides=tensor.stride(), |
| 209 | + storage_offset=tensor.storage_offset(), |
| 210 | + memory_format=suggest_memory_format(tensor), |
| 211 | + dtype=tensor.dtype, |
| 212 | + layout=tensor.layout, |
| 213 | + device=tensor.device, |
| 214 | + pin_memory=tensor.is_pinned(), |
| 215 | + requires_grad=tensor.requires_grad, |
| 216 | + ) |
| 217 | + |
| 218 | + def __init__( |
| 219 | + self, |
| 220 | + tensor: torch.Tensor, |
| 221 | + amax_buffer: torch.Tensor, |
| 222 | + amax_history_buffer: torch.Tensor, |
| 223 | + scale_buffer: torch.Tensor, |
| 224 | + mm_config: ScaledMMConfig, |
| 225 | + is_amax_initialized: bool, |
| 226 | + ): |
| 227 | + self._tensor = tensor |
| 228 | + self._amax_buffer = amax_buffer |
| 229 | + self._amax_history_buffer = amax_history_buffer |
| 230 | + self._scale_buffer = scale_buffer |
| 231 | + self._mm_config = mm_config |
| 232 | + |
| 233 | + # Note: is_amax_initialized is not a buffer to avoid data dependent |
| 234 | + # control flow visible to dynamo |
| 235 | + # TODO(future PR): add serialization for this flag |
| 236 | + self.is_amax_initialized = is_amax_initialized |
| 237 | + |
| 238 | + @classmethod |
| 239 | + def __torch_dispatch__(cls, func, types, args, kwargs=None): |
| 240 | + if func == torch.ops.aten.detach.default: |
| 241 | + return WeightWithDelayedFloat8CastTensor( |
| 242 | + args[0]._tensor, |
| 243 | + args[0]._amax_buffer, |
| 244 | + args[0]._amax_history_buffer, |
| 245 | + args[0]._scale_buffer, |
| 246 | + args[0]._mm_config, |
| 247 | + args[0].is_amax_initialized, |
| 248 | + ) |
| 249 | + mm_config: Optional[ScaledMMConfig] = None |
| 250 | + amax_buffer: Optional[torch.Tensor] = None |
| 251 | + amax_history_buffer: Optional[torch.Tensor] = None |
| 252 | + scale_buffer: Optional[torch.Tensor] = None |
| 253 | + is_amax_initialized: Optional[bool] = None |
| 254 | + |
| 255 | + def unwrap(t): |
| 256 | + nonlocal mm_config |
| 257 | + if mm_config is None: |
| 258 | + mm_config = t._mm_config |
| 259 | + else: |
| 260 | + mm_config = merge_mm_configs(mm_config, t._mm_config) |
| 261 | + nonlocal amax_buffer |
| 262 | + if amax_buffer is None: |
| 263 | + amax_buffer = t._amax_buffer |
| 264 | + nonlocal amax_history_buffer |
| 265 | + if amax_history_buffer is None: |
| 266 | + amax_history_buffer = t._amax_history_buffer |
| 267 | + nonlocal scale_buffer |
| 268 | + if scale_buffer is None: |
| 269 | + scale_buffer = t._scale_buffer |
| 270 | + nonlocal is_amax_initialized |
| 271 | + if is_amax_initialized is None: |
| 272 | + is_amax_initialized = t.is_amax_initialized |
| 273 | + return t._tensor |
| 274 | + |
| 275 | + args, kwargs = pytree.tree_map_only( |
| 276 | + WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {}) |
| 277 | + ) |
| 278 | + out = func(*args, **kwargs) |
| 279 | + if func not in _ops_to_preserve_subclass: |
| 280 | + return out |
| 281 | + return pytree.tree_map_only( |
| 282 | + torch.Tensor, |
| 283 | + lambda x: WeightWithDelayedFloat8CastTensor( |
| 284 | + x, |
| 285 | + amax_buffer, |
| 286 | + amax_history_buffer, |
| 287 | + scale_buffer, |
| 288 | + mm_config, |
| 289 | + is_amax_initialized, |
| 290 | + ), |
| 291 | + out, |
| 292 | + ) |
| 293 | + |
| 294 | + def __tensor_flatten__(self): |
| 295 | + return ( |
| 296 | + [ |
| 297 | + "_tensor", |
| 298 | + "_amax_buffer", |
| 299 | + "_amax_history_buffer", |
| 300 | + "_scale_buffer", |
| 301 | + ], |
| 302 | + { |
| 303 | + "mm_config": self._mm_config, |
| 304 | +            "is_amax_initialized": self.is_amax_initialized,
| 305 | + }, |
| 306 | + ) |
| 307 | + |
| 308 | + @staticmethod |
| 309 | + def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): |
| 310 | + return WeightWithDelayedFloat8CastTensor( |
| 311 | + inner_tensors["_tensor"], |
| 312 | + inner_tensors["_amax_buffer"], |
| 313 | + inner_tensors["_amax_history_buffer"], |
| 314 | + inner_tensors["_scale_buffer"], |
| 315 | + metadata["mm_config"], |
| 316 | + metadata["is_amax_initialized"], |
| 317 | + ) |
| 318 | + |
| 319 | + def __repr__(self): |
| 320 | + return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})" |
| 321 | + |
| 322 | + def fsdp_pre_all_gather(self, mesh): |
| 323 | + # initialize if needed |
| 324 | + # TODO(before land): ensure settings are consistent between Float8Linear and here |
| 325 | + if not self.is_amax_initialized: |
| 326 | + from float8_experimental.float8_linear import ( |
| 327 | + _maybe_initialize_amaxes_scales_for_float8_cast, |
| 328 | + ) |
| 329 | + |
| 330 | + _maybe_initialize_amaxes_scales_for_float8_cast( |
| 331 | + self._tensor, |
| 332 | + self._amax_buffer, |
| 333 | + self._amax_history_buffer, |
| 334 | + self._scale_buffer, |
| 335 | + "max", # TODO(before land): read this from parent |
| 336 | + e4m3_dtype, |
| 337 | + self.is_amax_initialized, |
| 338 | + reduce_amax=True, |
| 339 | + ) |
| 340 | + self.is_amax_initialized = True |
| 341 | + |
| 342 | + # this will: |
| 343 | + # 1. cast the tensor to float8 using `_scale_buffer` |
| 344 | + # 2. populate `_amax_buffer` inplace |
| 345 | + # TODO(future PR): clean up all the casting functions and clearly |
| 346 | + # separate dynamic vs delayed, tech debt has accumulated |
| 347 | + float8_tensor = Float8Tensor.to_float8( |
| 348 | + self._tensor, |
| 349 | + self._scale_buffer, |
| 350 | + e4m3_dtype, |
| 351 | + self._amax_buffer, |
| 352 | + self._mm_config, |
| 353 | + ) |
| 354 | + return (float8_tensor._data,), (float8_tensor._scale,) |
| 355 | + |
| 356 | + def fsdp_post_all_gather( |
| 357 | + self, |
| 358 | + all_gather_outputs: Tuple[torch.Tensor, ...], |
| 359 | + metadata: Any, |
| 360 | + param_dtype: torch.dtype, |
| 361 | + *, |
| 362 | + out: Optional[torch.Tensor] = None, |
| 363 | + ): |
| 364 | + (data,) = all_gather_outputs |
| 365 | + (scale,) = metadata |
| 366 | + if out is not None: |
| 367 | + assert isinstance(out, Float8Tensor), f"{type(out)}" |
| 368 | + out._scale = scale |
| 369 | + return |
| 370 | + return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,) |
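
Reviewer note, not part of the diff: a minimal single-process usage sketch of the new subclass. The module paths, buffer shapes, and the default-argument `ScaledMMConfig()` construction are assumptions for illustration only, and `is_amax_initialized=True` sidesteps the distributed amax initialization (which would otherwise require a process group for the `reduce_amax=True` path).

```python
import torch

# Hypothetical import paths; adjust to wherever this diff places these names.
from float8_experimental.fsdp_utils import WeightWithDelayedFloat8CastTensor
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig

weight = torch.randn(2048, 4096)
amax = torch.zeros(1)            # running amax, filled in-place by the fp8 cast
amax_history = torch.zeros(16)   # fixed-length history used to compute scale
scale = torch.ones(1)            # scale consumed by Float8Tensor.to_float8

w = WeightWithDelayedFloat8CastTensor(
    weight,
    amax,
    amax_history,
    scale,
    ScaledMMConfig(),          # assumes the namedtuple's fields have defaults
    is_amax_initialized=True,  # skip distributed amax init in this sketch
)

# aten.detach is special-cased in __torch_dispatch__, so FSDP's detach/view
# plumbing keeps seeing the fp8-aware wrapper rather than a plain tensor.
assert isinstance(w.detach(), WeightWithDelayedFloat8CastTensor)

# Round-trip the traceability protocol; the metadata dict carries mm_config
# and the is_amax_initialized flag alongside the four inner tensors.
names, meta = w.__tensor_flatten__()
w2 = WeightWithDelayedFloat8CastTensor.__tensor_unflatten__(
    {name: getattr(w, name) for name in names}, meta, w.size(), w.stride()
)
assert isinstance(w2, WeightWithDelayedFloat8CastTensor)

# The extension-point flow FSDP2 drives per all-gather: cast to fp8 (updating
# the amax buffer in place), then rewrap the gathered bytes as a Float8Tensor.
(data,), metadata = w.fsdp_pre_all_gather(mesh=None)  # mesh unused here
out, _ = w.fsdp_post_all_gather((data,), metadata, torch.bfloat16)
assert isinstance(out, Float8Tensor)
```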