|
23 | 23 | sync_float8_amax_and_scale_history, |
24 | 24 | ) |
25 | 25 | from float8_experimental.float8_python_api import addmm_float8_unwrapped |
26 | | -from float8_experimental.float8_tensor import Float8Tensor |
| 26 | +from float8_experimental.float8_tensor import ( |
| 27 | + Float8Tensor, |
| 28 | + merge_mm_configs, |
| 29 | + ScaledMMConfig, |
| 30 | +) |
27 | 31 | from float8_experimental.float8_utils import ( |
28 | 32 | amax_to_scale, |
29 | 33 | compute_error, |
@@ -326,6 +330,43 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): |
326 | 330 | atol, rtol = 2e-3, 2e-3 |
327 | 331 | torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) |
328 | 332 |
|
| 333 | + @unittest.skipIf(not is_H100, "H100 not available") |
| 334 | + def test_different_configs_error(self): |
| 335 | + x_fp32 = torch.randn(16, 16, device="cuda") |
| 336 | + x_scale = torch.tensor(1.0, device="cuda") |
| 337 | + fp8_dtype = torch.float8_e4m3fn |
| 338 | + a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype) |
| 339 | + b = Float8Tensor.to_float8( |
| 340 | + x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True) |
| 341 | + ) |
| 342 | + with pytest.raises( |
| 343 | + AssertionError, |
| 344 | + match="Both mm_configs must have the same emulate value, but got False and True", |
| 345 | + ): |
| 346 | + a @ b |
| 347 | + |
| 348 | + def test_merge_configs(self): |
| 349 | + a = ScaledMMConfig(False, True, True) |
| 350 | + b = ScaledMMConfig(True, False, False) |
| 351 | + with pytest.raises( |
| 352 | + AssertionError, |
| 353 | + match="Both mm_configs must have the same emulate value, but got False and True", |
| 354 | + ): |
| 355 | + merge_mm_configs(a, b) |
| 356 | + a = ScaledMMConfig(False, True, True) |
| 357 | + b = ScaledMMConfig(False, False, False) |
| 358 | + c = merge_mm_configs(a, b) |
| 359 | + assert c.emulate is False |
| 360 | + assert c.use_fast_accum is False |
| 361 | + assert c.fp8_output is False |
| 362 | + |
| 363 | + a = ScaledMMConfig(False, True, False) |
| 364 | + b = ScaledMMConfig(False, True, False) |
| 365 | + c = merge_mm_configs(a, b) |
| 366 | + assert c.emulate is False |
| 367 | + assert c.use_fast_accum is True |
| 368 | + assert c.fp8_output is False |
| 369 | + |
329 | 370 |
|
330 | 371 | class TestNumerics: |
331 | 372 | @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) |
|
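For context, the new test_merge_configs encodes the intended merge semantics: the emulate flags of the two operands must match, while use_fast_accum and fp8_output are only kept when both operands enable them. A minimal sketch of those semantics (hypothetical, inferred from the assertions in the diff; not the library's actual merge_mm_configs implementation) could look like:

```python
# Hypothetical sketch of the merge semantics exercised by test_merge_configs,
# assuming ScaledMMConfig is a namedtuple(emulate, use_fast_accum, fp8_output).
# Inferred from the test assertions above, not copied from the library.
from collections import namedtuple

ScaledMMConfig = namedtuple(
    "ScaledMMConfig", ["emulate", "use_fast_accum", "fp8_output"]
)


def merge_mm_configs(a: ScaledMMConfig, b: ScaledMMConfig) -> ScaledMMConfig:
    # Emulation mode must agree between the two operands.
    assert a.emulate == b.emulate, (
        f"Both mm_configs must have the same emulate value, "
        f"but got {a.emulate} and {b.emulate}"
    )
    # The remaining flags are only kept if both operands request them.
    return ScaledMMConfig(
        emulate=a.emulate,
        use_fast_accum=a.use_fast_accum and b.use_fast_accum,
        fp8_output=a.fp8_output and b.fp8_output,
    )


# Matches the expectations checked in test_merge_configs:
assert merge_mm_configs(
    ScaledMMConfig(False, True, True), ScaledMMConfig(False, False, False)
) == ScaledMMConfig(False, False, False)
assert merge_mm_configs(
    ScaledMMConfig(False, True, False), ScaledMMConfig(False, True, False)
) == ScaledMMConfig(False, True, False)
```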