diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py index 689a9d3ec9..fcf75dba1d 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py @@ -41,6 +41,7 @@ def _kernel_quantize_mx4_unpack( FP4_EXP_BIAS: tl.constexpr, GROUP_LOAD: tl.constexpr, USE_INT64: tl.constexpr, + SCALE_K: tl.constexpr, ) -> None: """Quantize a 1D float tensor into a packed MX4 tensor. @@ -64,21 +65,8 @@ def _kernel_quantize_mx4_unpack( USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors. """ # Define Constant Expressions. - FP16_EXP_MASK: tl.constexpr = 0x7F80 # type: ignore[Incompatible variable type] - FP16_EXP_OFFSET: tl.constexpr = 7 # type: ignore[Incompatible variable type] FP16_EXP_BIAS: tl.constexpr = 127 # type: ignore[Incompatible variable type] - FP16_SIGN_OFFSET: tl.constexpr = 15 # type: ignore[Incompatible variable type] - SIGN_MASK: tl.constexpr = 0x1 # type: ignore[Incompatible variable type] - FP16_MANTISSA_MASK: tl.constexpr = 0x007F # type: ignore[Incompatible variable type] - # FP4 has 2 mantissa bits, one explicit one implicit. - MBITS_IMPLICIT: tl.constexpr = MBITS + 1 # type: ignore[Incompatible variable type] - MAX_FP16_MANTISSA_BITS: tl.constexpr = 8 # type: ignore[Incompatible variable type] - IMPLIED_1_BIT: tl.constexpr = 1 << 7 # type: ignore[Incompatible variable type] BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126) # type: ignore[Incompatible variable type] - MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1 # type: ignore[Incompatible variable type] - EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1 # type: ignore[Incompatible variable type] - IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1 - RAND_MASK: tl.constexpr = (1 << (FP16_EXP_OFFSET - MBITS)) - 1 # type: ignore[Incompatible variable type] # Get the current thread number. pid = tl.program_id(0) @@ -91,9 +79,9 @@ def _kernel_quantize_mx4_unpack( # Boundaries for writing to output tensor. NUM_GROUPS = M * GROUPS_PER_ROW - OUTPUT_CHUNK_SIZE = (GROUPS_PER_THREAD * GROUP_SIZE) // 2 + OUTPUT_CHUNK_SIZE = (GROUPS_PER_THREAD * GROUP_SIZE) // 8 SCALE_CHUNK_SIZE = GROUPS_PER_THREAD - OUTPUT_SIZE = (GROUP_SIZE * NUM_GROUPS) // 2 + OUTPUT_SIZE = (GROUP_SIZE * NUM_GROUPS) // 8 SCALE_SIZE = NUM_GROUPS # Find starting offsets for this thread. These are calculated before adjusting for padding. @@ -102,13 +90,8 @@ def _kernel_quantize_mx4_unpack( exp_start = pid * SCALE_CHUNK_SIZE # Initiate offset ranges used in kernel. input_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE) + input_start - output_offset = tl.arange(0, GROUP_LOAD * (GROUP_SIZE // 2)) + output_start - # Stochastic rounding loads chunks of random values. - if ROUNDING_MODE == 3: - rand_bits_offset = tl.arange(0, GROUP_LOAD) + pid * GROUPS_PER_THREAD - # Ceil rounding uses single values as a seed. - else: - rand_bits_offset = pid * GROUPS_PER_THREAD + output_offset = tl.arange(0, GROUP_LOAD * (GROUP_SIZE // 8)) + output_start + exp_offset = tl.arange(0, GROUP_LOAD) + exp_start # We need to shift output offsets to make space for shared exponent storage. # Now create offsets for writing the shared exponent. exp_offset = tl.arange(0, GROUP_LOAD) + exp_start @@ -141,7 +124,7 @@ def _kernel_quantize_mx4_unpack( ############## # View the block in terms of groups. 
- a_groups = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) + a_groups = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]).to(tl.float32) # Compute the shared exponent of each group. group_max = tl.max(tl.abs(a_groups), axis=1) # Prevent infinite values in log. @@ -149,13 +132,6 @@ def _kernel_quantize_mx4_unpack( # Load relevant random values if doing stochastic rounding # or stochastic casting. group_rand_bits = None - if (ROUNDING_MODE) == 3 or STOCHASTIC_CASTING: - group_rand_bits = tl.load( - rand_bits + rand_bits_offset, - mask=rand_bits_offset < K // GROUP_SIZE, - other=0, - ) - rand_bits_offset += GROUP_LOAD # Compute shared exponent using specified rounding mode. group_exp = _compute_exp(group_max, ROUNDING_MODE, group_rand_bits, MBITS) # Subtract largest exponent in target datatype and remove bias. @@ -165,134 +141,83 @@ def _kernel_quantize_mx4_unpack( # Next we scale A in preparation for quantization. # TODO: We convert to float16 rather than bf16 due to numerical accuracy, but we might need to consider fp32 - scale_ = tl.exp2(group_exp.to(tl.float64)).to(tl.float16) + scale_ = tl.exp2(group_exp.to(tl.float64)).to(tl.float32) # Apply scale_ to input. We do this by broadcasting scale. - scaled_a = ( - tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) - / tl.reshape(scale_, [GROUP_LOAD, 1]) - ).to(tl.bfloat16) + scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) / tl.reshape( + scale_, [GROUP_LOAD, 1] + ) # Reshape back to a flat array. scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE]) + temp_l, temp_r = tl.split( + tl.reshape(scaled_a, [(GROUP_LOAD * GROUP_SIZE) // 2, 2]) + ) # 0, 2, 4, 6, 8 || 1, 3, 5, 7, 9 + t_one, t_two = tl.split( + tl.reshape(temp_l, [(GROUP_LOAD * GROUP_SIZE) // 4, 2]) + ) # 0 4 8 || 2, 6, 10 + t_three, t_four = tl.split( + tl.reshape(temp_r, [(GROUP_LOAD * GROUP_SIZE) // 4, 2]) + ) # 1, 5, 9 || 3, 7, 11 + + f_one, f_two = tl.split( + tl.reshape(t_one, [(GROUP_LOAD * GROUP_SIZE) // 8, 2]) + ) # 0, 8 || 4, 12 + f_three, f_four = tl.split( + tl.reshape(t_two, [(GROUP_LOAD * GROUP_SIZE) // 8, 2]) + ) # 2, 10 || 6, 14 + f_five, f_six = tl.split( + tl.reshape(t_three, [(GROUP_LOAD * GROUP_SIZE) // 8, 2]) + ) # 1, 9 || 5, 13 + f_seven, f_eight = tl.split( + tl.reshape(t_four, [(GROUP_LOAD * GROUP_SIZE) // 8, 2]) + ) # 3, 11 || 7, 15 + packed_result = tl.inline_asm_elementwise( + asm=""" + { + .reg .b8 byte0; + .reg .b8 byte1; + .reg .b8 byte2; + .reg .b8 byte3; + cvt.rn.satfinite.e2m1x2.f32 byte0, $2, $1; + cvt.rn.satfinite.e2m1x2.f32 byte1, $4, $3; + cvt.rn.satfinite.e2m1x2.f32 byte2, $6, $5; + cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $7; + mov.b32 $0, {byte0, byte1, byte2, byte3}; + + } + """, + constraints="=r," "f, f, f, f, f, f, f, f", + args=[f_one, f_five, f_three, f_seven, f_two, f_six, f_four, f_eight], + dtype=tl.int32, + is_pure=True, + pack=1, + ) + + n_col_blocks = SCALE_K // 4 + first_dim = exp_offset // (512 * n_col_blocks) + second_dim = (exp_offset % (512 * n_col_blocks)) // (128 * n_col_blocks) + third_dim = (exp_offset % (128 * n_col_blocks)) // (4 * n_col_blocks) + fourth_dim = (exp_offset % (4 * n_col_blocks)) // 4 + fifth_dim = exp_offset % 4 + actual_offset = ( + first_dim * (512 * n_col_blocks) + + fourth_dim * (512) + + third_dim * (16) + + second_dim * (4) + + fifth_dim + ) # We're done with group_exp now so we can write it out. - # We readd fp16_exp_bias for compatibility with cuda dequant. tl.store( - scale + exp_offset, + scale + actual_offset, (group_exp + FP16_EXP_BIAS).to(tl.int8), # Prevent writing outside this chunk or the main array. 
mask=(exp_offset < SCALE_SIZE) & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1))), ) - - # Quantization step - ################### - - # During quantization, we're going to be doing a lot of bitwise operations. - # This is easier to work with in int32. - scaled_a = scaled_a.to(tl.int16, bitcast=True) - - # When doing stochastic downcasting, generate random values for this block - # and apply it to the mantissa. - if STOCHASTIC_CASTING: - # We're going to generate 4 blocks at once so we only need - # one fourth of the input offsets. - # Start by splitting down to half of offsets. - philox_4x_offset = tl.split( - tl.reshape( - input_offset, - [GROUP_LOAD * GROUP_SIZE // 2, 2], - can_reorder=True, - ) - ) - # Split down to fourth. - philox_4x_offset = tl.split( - tl.reshape( - philox_4x_offset, - [GROUP_LOAD * GROUP_SIZE // 4, 2], - can_reorder=True, - ) - ) - # Generate 4 blocks of random bits for this block. - a_4x, b_4x, c_4x, d_4x = tl.randint4x( - group_rand_bits, philox_4x_offset, n_rounds=7 - ) - # Combine the 4 blocks into a single chunk of random values. - # This needs to be done incrementally. - stochastic_round_bits = tl.join(tl.join(a_4x, b_4x), tl.join(c_4x, d_4x)) - # Flatten back to simple array. - stochastic_round_bits = tl.reshape( - stochastic_round_bits, [GROUP_LOAD * GROUP_SIZE] - ).to(tl.int16, bitcast=True) - - # Mask off mantissa bits of random value and add to mantissa. - scaled_a = scaled_a + (stochastic_round_bits & RAND_MASK) - - # Extract sign bit of value. - sign_bit = (scaled_a >> FP16_SIGN_OFFSET) & SIGN_MASK - - # Extract exponent. - biased_exp = (scaled_a & FP16_EXP_MASK) >> FP16_EXP_OFFSET - - # Extract mantissa. - trailing_mantissa = scaled_a & FP16_MANTISSA_MASK - - # Adjust exponent bias for FP4. - new_biased_exp = biased_exp - FP16_EXP_BIAS + FP4_EXP_BIAS - - # Compute difference between ideal exponent and what fp4 can represent. - exp_diff = tl.where(new_biased_exp <= 0, 1 - new_biased_exp, 0) - - # Clip this difference to maximum number of fp32 mantissa bits. - exp_diff = tl.minimum(exp_diff, MAX_FP16_MANTISSA_BITS) - - # Now we round our fp32 mantissa down to fp4. - is_subnorm = biased_exp == 0 - # Add implied 1 bit to normal values. - mantissa = tl.where( - is_subnorm, trailing_mantissa, trailing_mantissa + IMPLIED_1_BIT - ) - # Compute base number of bits corresponding to the mantissa, smaller for subnorms - # since implied one is included in exp_diff. - fp16_sig_bits = tl.where(is_subnorm, 7, 8).to(tl.int32) - # Now we're ready to shift down to target bitwidth (with an extra bit for rounding). - mantissa = mantissa >> (fp16_sig_bits + exp_diff - MBITS_IMPLICIT - 1) - # Perform rounding by adding 1 and shifting down. - mantissa = (mantissa + 1) >> 1 - - # Check for overflow and adjust exponent accordingly. - overflow = mantissa > MANTISSA_OVERFLOW_THRESHOLD - # Allow subnorms to overflow into normals, otherwise shift away overflow. - mantissa = tl.where(overflow and (not is_subnorm), mantissa >> 1, mantissa) - # Special case where a value is subnormal and has a large mantissa, overflow it. - new_biased_exp = tl.where( - (new_biased_exp <= 0) and (mantissa == 2), 1, new_biased_exp - ) - # Remove implicit 1. - mantissa = mantissa & IMPLICIT_1_MASK - # Add overflow to exponent. - new_biased_exp = tl.where(overflow, new_biased_exp + 1, new_biased_exp) - # If exp overflows, set mantissa to maximum value (equivalent to clamping). - mantissa = tl.where(new_biased_exp > EXPONENT_OVERFLOW_THRESHOLD, 1, mantissa) - - # Construct FP4 value from components. 
- new_biased_exp = tl.maximum( - tl.minimum(new_biased_exp, EXPONENT_OVERFLOW_THRESHOLD), 0 - ) - - mx4_value = (new_biased_exp << (MBITS_IMPLICIT - 1)) | mantissa - mx4_value = (sign_bit << (EBITS + MBITS)) | mx4_value - - # Extract low and high bits from values. - low_mx4, high_mx4 = tl.split( - tl.reshape(mx4_value, [(GROUP_LOAD * GROUP_SIZE) // 2, 2]) - ) - # Shift mx4 values together so they are packed into int8. - packed_mx4 = ((high_mx4 << 4) | (low_mx4)).to(tl.int8) - # Write out packed values to output tensor. tl.store( out + output_offset, - packed_mx4, + packed_result, # Prevent writing outside this chunk or the main array. mask=(output_offset < OUTPUT_SIZE) & (output_offset < (OUTPUT_CHUNK_SIZE * (pid + 1))), @@ -301,7 +226,7 @@ def _kernel_quantize_mx4_unpack( # Update offsets so we work on the next block. input_offset += GROUP_LOAD * GROUP_SIZE exp_offset += GROUP_LOAD - output_offset += GROUP_LOAD * GROUP_SIZE // 2 + output_offset += GROUP_LOAD * GROUP_SIZE // 8 def _to_blocked(x: torch.Tensor) -> torch.Tensor: @@ -344,7 +269,7 @@ def ceil_div(a: int, b: int) -> int: def triton_quantize_mx4_unpack( - a: torch.Tensor, + input: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1, @@ -368,23 +293,41 @@ def triton_quantize_mx4_unpack( torch.Tensor: [M / group_size] mx4 shared exponents into int8 eg. - Input with shape [1, 8192] will be quantized to [1, 4096 + 256] as + Input with shape [1, 8192] will be quantized to [1, 4096 + 512] as each value contains two elements packed into an int8 and - there are 32 groups in each row. + there are 32 elements per group. """ - # If given an empty shape, return an empty tensor. - if a.numel() == 0: - return torch.empty(a.shape, device=a.device, dtype=torch.uint8), torch.empty( - a.shape, device=a.device, dtype=torch.uint8 - ) - # Make sure input is continuous in memory. - assert a.is_contiguous(), "Inputs to mx4 quantize must be contiguous in memory." - orig_shape = a.shape - # For simplicity, view input as a 2D array. - a = a.view(-1, a.shape[-1]) - # Extract rows and columns. - M, K = a.shape + orig_shape = input.shape + assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}." + other_dims = 1 if input.ndim == 1 else -1 + input = input.reshape(other_dims, input.shape[-1]) + M, K = input.shape + block_size = group_size + device = input.device + + assert K % block_size == 0, f"last dim has to be a multiple of {block_size}, but got {K}." + assert input.dtype in ( + torch.float16, + torch.bfloat16, + ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}." + + # Eight fp4 values will be packed into a uint32. + out = torch.empty((M, K // 8), device=device, dtype=torch.uint32) + + # We use the rounded values to store the swizzled values. Due to the + # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. + # So, we first pad the scales to multiples of 128 and 4. Then, the scales + # (biased e8m0 exponents) are stored as int8. More: + # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x + def round_up(x: int, y: int) -> int: + return (x + y - 1) // y * y + + rounded_M = round_up(M, 128) + scale_K = K // block_size + rounded_K = round_up(scale_K, 4) + scale = torch.empty((rounded_M, rounded_K), device=device, dtype=torch.int8) + # In this kernel, we want each row to be divisible by group_size. # If the rows are not, then we will pad them. Find the number of # groups per row after padding.
@@ -392,7 +335,7 @@ def triton_quantize_mx4_unpack( num_groups = M * groups_per_row # Find how many groups each thread should process. We do this # by assuming that it is good to distribute work evenly over threads. - num_threads = math.ceil(math.sqrt(a.numel())) + num_threads = math.ceil(math.sqrt(input.numel())) # Data is loaded in chunks of GROUP_LOAD elements, so theres no reason # to ever fewer groups per thread than it. GROUP_LOAD = 64 @@ -403,12 +346,6 @@ def triton_quantize_mx4_unpack( else: padding = 0 - # Create output tensor. - out_elems = (num_groups * group_size) // 2 - scale_elems = num_groups - out = torch.empty([out_elems], device=a.device, dtype=torch.uint8) - scale = torch.empty([scale_elems], device=a.device, dtype=torch.uint8) - # If using stochastic rounding, create random noise for each group. # We use the same random bits as seeds when doing stochastic downcasting. if rounding_mode == RoundingMode.stochastic or stochastic_casting: @@ -418,7 +355,7 @@ def triton_quantize_mx4_unpack( high=2**31 - 1, size=(num_groups,), dtype=torch.int32, - device=a.device, + device=input.device, ) else: rand_bits = None @@ -426,9 +363,10 @@ def triton_quantize_mx4_unpack( # Check if we need to use int64 for indexing. use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1 # Invoke triton quantization kernel over rows. + grid = (num_threads,) _kernel_quantize_mx4_unpack[grid]( - a, + input, out, scale, rand_bits=rand_bits, @@ -452,12 +390,12 @@ def triton_quantize_mx4_unpack( GROUP_LOAD=GROUP_LOAD, # pyre-ignore[6] USE_INT64=use_int64, + # pyre-ignore[6] + SCALE_K=rounded_K, ) - scale = scale.view(torch.float8_e8m0fnu) - scale = scale.view(orig_shape[0], -1) - scale = _to_blocked(scale) - return out.view(list(orig_shape[:-1]) + [-1]), scale + scale = scale.flatten() + return out.view(list(orig_shape[:-1]) + [-1]).view(torch.uint8), scale @triton.jit diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 3cc0dfbc09..6eba961fb9 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -4359,3 +4359,74 @@ def grid(meta): xq, x_scale, x_dequant, M, K, BLOCK_M=block_m, BLOCK_K=block_k # pyre-ignore[6] ) return x_dequant + + +# This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142 +def to_mxfp8( + data_hp: torch.Tensor, + block_size: int = 32, +): + assert data_hp.dtype in ( + torch.bfloat16, + torch.float, + ), f"{data_hp.dtype} is not supported yet" + assert ( + data_hp.shape[-1] % block_size == 0 + ), f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}" + assert data_hp.is_contiguous(), "unsupported" + + orig_shape = data_hp.shape + data_hp = data_hp.reshape( + *orig_shape[:-1], orig_shape[-1] // block_size, block_size + ) + + max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1) + + data_hp = data_hp.to(torch.float32) + max_abs = max_abs.to(torch.float32) + + F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 + max_pos = F8E4M3_MAX + + # RCEIL + def _to_mx_rceil( + data_hp: torch.Tensor, + max_abs: torch.Tensor, + max_pos: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + E8M0_EXPONENT_BIAS = 127 + descale = max_abs / max_pos + exponent = torch.where( + torch.isnan(descale), + 0xFF, # Handle biased exponent for nan + # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping 
+ ( + torch.clamp( + torch.ceil(torch.log2(descale)), + min=-E8M0_EXPONENT_BIAS, + max=E8M0_EXPONENT_BIAS, + ) + + E8M0_EXPONENT_BIAS + ).to(torch.uint8), + ) + + descale_fp = torch.where( + exponent == 0, + 1.0, + torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)), + ) + + # scale and saturated cast the data elements to max of target dtype + data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos) + return exponent, data_lp + + scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) + + # cast to target dtype + data_lp = data_lp.to(torch.float8_e4m3fn) + # need to reshape at the end to help inductor fuse things + data_lp = data_lp.reshape(orig_shape) + + scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) + scale_e8m0_biased = scale_e8m0_biased.squeeze(-1) + return scale_e8m0_biased, data_lp diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index 5ff8309102..6efe1f9af7 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -15,6 +15,7 @@ import triton # @manual=//triton:triton from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import ( + _to_blocked, calculate_group_max, mega_fp4_pack, mega_fp4_quantize_kernel, @@ -33,6 +34,7 @@ quantize_fp8_group, quantize_fp8_row, scale_fp8_row, + to_mxfp8, triton_quantize_fp8_row, ) from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import ( @@ -2497,20 +2499,53 @@ def preprocess(self, x, w): return x, wq, w_scale, m_sizes def quantize(self, x, wq, w_scale, m_sizes): - xq, x_scale = zip(*[triton_quantize_mx4_unpack(i) for i in x]) - xq = torch.stack(xq, dim=0).contiguous() - x_scale = torch.stack(x_scale, dim=0).contiguous() + starting_row_after_padding_list = [0] + xq_list = [] + x_scale_list = [] + for i in range(m_sizes.shape[0]): + scale_slice = x[i] + if m_sizes[i].item() != 0: + xq, x_scale = triton_quantize_mx4_unpack(scale_slice) + xq_list.append(xq) + x_scale_list.append(x_scale) + starting_row_after_padding_list.append( + starting_row_after_padding_list[i] + + x_scale.numel() // (x[0].shape[1] // 32) + ) + else: + starting_row_after_padding_list.append( + starting_row_after_padding_list[i] + ) + xq = torch.cat(xq_list, dim=0).contiguous() + x_scale = torch.cat(x_scale_list, dim=0).contiguous() + x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32) xq = xq.view(-1, xq.shape[-1]) - return xq, wq, x_scale, w_scale, m_sizes + return ( + xq, + wq, + x_scale, + w_scale, + m_sizes, + torch.tensor(starting_row_after_padding_list, device=xq.device), + ) - def compute(self, xq, wq, x_scale, w_scale, m_sizes): + def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding): return torch.ops.fbgemm.f4f4bf16_grouped_stacked( - xq, wq, x_scale, w_scale, m_sizes + xq, + wq, + x_scale, + w_scale, + m_sizes, + starting_row_after_padding=starting_row_after_padding, ) def quantize_and_compute(self, x, w): - xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, w) - return self.compute(xq, wq, x_scale, w_scale, m_sizes) + xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize( + x, w + ) + return self.compute( + xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding + ) @property def name(self) -> str: @@ -2835,3 +2870,86 @@ def hip(self) -> bool: @property def cuda(self) -> bool: return True + + +@register_quantize_op +class MXFP8StackedGroupedGemm(QuantizeOpBase): + """ + MXFP8 grouped matmul with 
blockwise scaling and stacked inputs. + """ + + def preprocess(self, x, w): + m_values = [i.shape[0] for i in x] + m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device) + wq_list = [] + w_scale_list = [] + for i in range(m_sizes.shape[0]): + w_scale, wq = to_mxfp8(w[i]) + w_scale = _to_blocked(w_scale) + wq_list.append(wq) + w_scale_list.append(w_scale) + wq = torch.stack(wq_list, dim=0).contiguous() + w_scale = torch.stack(w_scale_list, dim=0).contiguous() + return x, wq, w_scale, m_sizes + + def quantize(self, x, wq, w_scale, m_sizes): + starting_row_after_padding_list = [0] + xq_list = [] + x_scale_list = [] + for i in range(m_sizes.shape[0]): + scale_slice = x[i] + if m_sizes[i].item() != 0: + x_scale, xq = to_mxfp8(scale_slice) + x_scale = _to_blocked(x_scale) + xq_list.append(xq) + x_scale_list.append(x_scale) + starting_row_after_padding_list.append( + starting_row_after_padding_list[i] + + x_scale.numel() // (x[0].shape[1] // 32) + ) + else: + starting_row_after_padding_list.append( + starting_row_after_padding_list[i] + ) + xq = torch.cat(xq_list, dim=0).contiguous() + x_scale = torch.cat(x_scale_list, dim=0).contiguous() + x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32) + xq = xq.view(-1, xq.shape[-1]) + return ( + xq, + wq, + x_scale, + w_scale, + m_sizes, + torch.tensor(starting_row_after_padding_list, device=xq.device), + ) + + def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding): + return torch.ops.fbgemm.mx8mx8bf16_grouped_stacked( + xq, + wq, + x_scale, + w_scale, + m_sizes, + starting_row_after_padding=starting_row_after_padding, + ) + + def quantize_and_compute(self, x, w): + xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize( + x, w + ) + return self.compute( + xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding + ) + + @property + def name(self) -> str: + return "cutlass_mx8mx8bf16_grouped_stacked" + + @property + def hip(self) -> bool: + return False + + @property + def cuda(self) -> bool: + return True diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_common.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_common.cuh index a472bb73dd..262b290e13 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_common.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_common.cuh @@ -160,15 +160,21 @@ __global__ void set_stacked_kernel_args_kernel( int64_t offset_M = 0; int64_t accumulated_x_scale = 0; int64_t accumulated_w_scale = 0; + int ele_per_quantize_group = 16; + if (global_scale == nullptr) { + ele_per_quantize_group = 32; + } for (int i = 0; i < group_index; i++) { offset_M += M_sizes[i]; /* It's calculated this way since the scales are at least padded to multiples of (128, 4), and there is a group of ele_per_quantize_group elements (16 for NVFP4, 32 for MXFP4) per scale. */ accumulated_w_scale += - (((N + 128 - 1) / 128) * 128 * ((K + 4 - 1) / 4) * 4 / 16); + (((N + 128 - 1) / 128) * 128 * ((K + 4 - 1) / 4) * 4 / + ele_per_quantize_group); } - accumulated_x_scale = starting_row_after_padding[group_index] * K / 16; + accumulated_x_scale = + starting_row_after_padding[group_index] * K / ele_per_quantize_group; // Set the problem shape for this group. problem_shape_ptr[non_zero_idx] = ProblemShape(N, M, K); // Set input pointers.
@@ -646,7 +652,7 @@ at::Tensor f4f4bf16_grouped_impl( layout_SFB, nullptr, nullptr, - nullptr); + starting_row_after_padding_ptr); } // Set the number of groups to the kernel to be at most the number of // non-zero rows. diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu new file mode 100644 index 0000000000..693c05739a --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu @@ -0,0 +1,185 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) +#include "mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh" +#endif + +namespace fbgemm_gpu { + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) + +template +Kernel_mx8mx8bf16_grouped +get_kernel_via_heuristics(int total_M, int N, int K, int G) { + // Llama4 shapes + if (N == 5120 && K == 1024) { + if (G <= 8) { + if (total_M <= 256) { + return mx8mx8bf16_grouped_256_64_256_2_1_1; + } else if (total_M <= 512) { + return mx8mx8bf16_grouped_128_64_256_1_1_1; + } else if (total_M <= 1024) { + return mx8mx8bf16_grouped_128_128_256_1_1_1; + } + } else if (G <= 16) { + if (total_M <= 1024) { + return mx8mx8bf16_grouped_128_64_256_1_1_1; + } else if (total_M <= 2048) { + return mx8mx8bf16_grouped_256_128_256_2_1_1; + } + } else { + if (total_M <= 1024) { + return mx8mx8bf16_grouped_256_64_256_2_1_1; + } else if (total_M <= 4096) { + return mx8mx8bf16_grouped_128_64_256_1_1_1; + } else if (total_M <= 8192) { + return mx8mx8bf16_grouped_256_64_256_2_1_1; + } + } + return mx8mx8bf16_grouped_256_256_256_2_1_1; + } else if (N == 2048 && K == 5120) { + if (G <= 8) { + if (total_M <= 256) { + return mx8mx8bf16_grouped_256_64_256_2_1_1; + } else if (total_M <= 512) { + return mx8mx8bf16_grouped_128_64_256_1_1_1; + } else if (total_M <= 1024) { + return mx8mx8bf16_grouped_128_128_256_1_1_1; + } + } else if (G <= 16) { + if (total_M <= 1024) { + return mx8mx8bf16_grouped_256_64_256_2_1_1; + } else if (total_M <= 2048) { + return mx8mx8bf16_grouped_128_128_256_1_1_1; + } + } else { + if (total_M <= 1024) { + return mx8mx8bf16_grouped_256_64_256_2_1_1; + } else if (total_M <= 16384) { + return mx8mx8bf16_grouped_256_128_256_2_1_1; + } + } + return mx8mx8bf16_grouped_256_256_256_2_1_1; + } + + // Fallback to legacy heuristic + if (total_M <= 1000) { + return mx8mx8bf16_grouped_256_128_256_2_1_1; + } else { + return mx8mx8bf16_grouped_256_256_256_2_1_1; + } +} + +template +at::Tensor dispatch_mx8_grouped_kernel( + int total_M, + int N, + int K, + int G, + InputType XQ, // FP8 + InputType WQ, // FP8 + InputType x_scale, + InputType w_scale, + at::Tensor output, + std::optional zero_start_index_M = std::nullopt, + std::optional M_sizes = std::nullopt, + std::optional starting_row_after_padding = std::nullopt) { + TORCH_CHECK( + zero_start_index_M.has_value() != M_sizes.has_value(), + "One of zero_start_index_M or M_sizes must be provided."); + TORCH_CHECK(M_sizes.has_value(), "M_sizes is assumed to be provided."); + TORCH_CHECK( + 
starting_row_after_padding.has_value(), + "starting_row_after_padding is assumed to be provided."); + at::Tensor starting_row_after_padding_actual = + starting_row_after_padding.value_or(at::zeros({0})); + TORCH_CHECK(starting_row_after_padding_actual.size(0) % (G + 1) == 0); + + // Select kernel to run via heuristics. + auto kernel = [&]() { + return get_kernel_via_heuristics(total_M, N, K, G); + }(); + // Invoke kernel + return kernel( + XQ, + WQ, + x_scale, + w_scale, + output, + G, + zero_start_index_M, + M_sizes, + starting_row_after_padding); +} + +at::Tensor mx8mx8bf16_grouped_stacked( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor M_sizes, + std::optional starting_row_after_padding = std::nullopt) { + int64_t total_M = XQ.size(0); + int64_t N = WQ.size(1); + int64_t K = WQ.size(2); + int64_t G = M_sizes.size(0); + TORCH_CHECK( + M_sizes.device() == XQ.device(), + "M_sizes must be on same device as inputs."); + TORCH_CHECK( + WQ.dim() == 3 && WQ.size(0) == G, "Weights should be shape [G, N, K].") + at::Tensor Y = at::empty({total_M, N}, XQ.options().dtype(at::kBFloat16)); + // Early exit for empty inputs. + if (total_M == 0) { + return Y; + } + // Return continuous view of output. + return dispatch_mx8_grouped_kernel( + total_M, + N, + K, + G, + XQ, + WQ, + x_scale, + w_scale, + Y, + std::nullopt, + M_sizes, + starting_row_after_padding); +} + +#else + +at::Tensor mx8mx8bf16_grouped_stacked( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor M_sizes, + std::optional starting_row_after_padding = std::nullopt) { + throw std::runtime_error( + "CUDA version is older than 12.8"); // requires CUDA>=12.8 +} +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_128_256_1_1_1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_128_256_1_1_1.cu new file mode 100644 index 0000000000..778414b91a --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_128_256_1_1_1.cu @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "mx8mx8bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) + +at::Tensor mx8mx8bf16_grouped_128_128_256_1_1_1( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding) { + return mx8mx8bf16_grouped_impl( + XQ, + WQ, + x_scale, + w_scale, + output, + G, + zero_start_index_M, + M_sizes, + starting_row_after_padding); +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_64_256_1_1_1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_64_256_1_1_1.cu new file mode 100644 index 0000000000..ae8b8a5e08 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_64_256_1_1_1.cu @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "mx8mx8bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) + +at::Tensor mx8mx8bf16_grouped_128_64_256_1_1_1( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding) { + return mx8mx8bf16_grouped_impl( + XQ, + WQ, + x_scale, + w_scale, + output, + G, + zero_start_index_M, + M_sizes, + starting_row_after_padding); +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_128_256_2_1_1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_128_256_2_1_1.cu new file mode 100644 index 0000000000..7142ef01c5 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_128_256_2_1_1.cu @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "mx8mx8bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) + +at::Tensor mx8mx8bf16_grouped_256_128_256_2_1_1( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding) { + return mx8mx8bf16_grouped_impl( + XQ, + WQ, + x_scale, + w_scale, + output, + G, + zero_start_index_M, + M_sizes, + starting_row_after_padding); +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_256_256_2_1_1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_256_256_2_1_1.cu new file mode 100644 index 0000000000..f9e4444603 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_256_256_2_1_1.cu @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "mx8mx8bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) + +at::Tensor mx8mx8bf16_grouped_256_256_256_2_1_1( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding) { + return mx8mx8bf16_grouped_impl( + XQ, + WQ, + x_scale, + w_scale, + output, + G, + zero_start_index_M, + M_sizes, + starting_row_after_padding); +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_64_256_2_1_1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_64_256_2_1_1.cu new file mode 100644 index 0000000000..601e0904ff --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_64_256_2_1_1.cu @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "mx8mx8bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) + +at::Tensor mx8mx8bf16_grouped_256_64_256_2_1_1( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding) { + return mx8mx8bf16_grouped_impl( + XQ, + WQ, + x_scale, + w_scale, + output, + G, + zero_start_index_M, + M_sizes, + starting_row_after_padding); +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_common.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_common.cuh new file mode 100644 index 0000000000..8e2cc7f008 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_common.cuh @@ -0,0 +1,528 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +// clang-format off + // The fixed ordering of the headers is required for CUTLASS 3.2+ + #include + #include // @manual + #include // @manual + #include // @manual +// clang-format on + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) + +inline int64_t _byte_align(int64_t offset) { + int64_t remainder = offset % 16; + if (remainder != 0) { + offset += (16 - remainder); + } + return offset; +} + +template < + typename ProblemShape, + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementComputeEpilogue, + typename StrideA, + typename StrideB, + typename StrideC, + typename LayoutSFA, + typename LayoutSFB, + typename ElementGlobalScale, + typename Sm1xxBlkScaledConfig> +__global__ void set_kernel_args_kernel( + int i, // Group index + int64_t G, // Total groups. + int64_t M, + int64_t N, + int64_t K, + ProblemShape* problem_shape_ptr, + ElementA* xq, + const ElementA** xq_ptr, + ElementB* wq, + const ElementB** wq_ptr, + ElementComputeEpilogue* x_scale, + const ElementComputeEpilogue** x_scale_ptr, + ElementComputeEpilogue* w_scale, + const ElementComputeEpilogue** w_scale_ptr, + ElementC* output, + ElementC** output_ptr, + StrideA* stride_a_ptr, + StrideB* stride_b_ptr, + StrideC* stride_c_ptr, + LayoutSFA* layout_SFA, + LayoutSFB* layout_SFB) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + // Each kernel annoyingly can only set the kernel args for one group. + // This could only be avoided with complicated memory management. 
+ if (idx == 0) { + problem_shape_ptr[i] = ProblemShape(N, M, K); + xq_ptr[i] = xq; + wq_ptr[i] = wq; + x_scale_ptr[i] = x_scale; + w_scale_ptr[i] = w_scale; + output_ptr[i] = output; + stride_a_ptr[i] = cutlass::make_cute_packed_stride( + StrideA{}, cute::make_shape(int(M), int(K), 1)); + stride_b_ptr[i] = cutlass::make_cute_packed_stride( + StrideB{}, cute::make_shape(int(N), int(K), 1)); + stride_c_ptr[i] = cutlass::make_cute_packed_stride( + StrideC{}, cute::make_shape(int(N), int(M), 1)); + layout_SFA[i] = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( + cute::make_shape(int(M), int(N), int(K), 1)); + layout_SFB[i] = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( + cute::make_shape(int(M), int(N), int(K), 1)); + } +} + +template < + typename ProblemShape, + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementComputeEpilogue, + typename StrideA, + typename StrideB, + typename StrideC, + typename LayoutSFA, + typename LayoutSFB, + typename ElementGlobalScale, + typename Sm1xxBlkScaledConfig> +__global__ void set_stacked_kernel_args_kernel( + int64_t G, + int64_t N, + int64_t K, + int64_t num_x_scale_per_group, + int64_t num_w_scale_per_group, + ProblemShape* problem_shape_ptr, + ElementA* xq, + const ElementA** xq_ptr, + ElementB* wq, + const ElementB** wq_ptr, + ElementComputeEpilogue* x_scale, + const ElementComputeEpilogue** x_scale_ptr, + ElementComputeEpilogue* w_scale, + const ElementComputeEpilogue** w_scale_ptr, + ElementC* output, + ElementC** output_ptr, + StrideA* stride_a_ptr, + StrideB* stride_b_ptr, + StrideC* stride_c_ptr, + int64_t* M_sizes, + LayoutSFA* layout_SFA, + LayoutSFB* layout_SFB, + int64_t* starting_row_after_padding) { + uint32_t group_index = blockIdx.x * blockDim.x + threadIdx.x; + // If this thread corresponds to a valid group, write kernel args to device + // memory. + if (group_index < G) { + // Its possible that we're only writing a subset of the groups to + // kernel args. To do this, we need to set all groups initially to empty. + // and keep a problem counter for the number of non-empty groups. + __shared__ int non_zero_counter; + // Initialize counter in first group. + if (group_index == 0) { + non_zero_counter = 0; + } + // Set problem shapes to empty by default. + problem_shape_ptr[group_index] = ProblemShape(0, 0, 0); + // Sync threads to get consistent state in the block. + __syncthreads(); + + // Compute shape for this group. + // M for this group is pulled directly from M_sizes. + int M = M_sizes[group_index]; + // Only proceed to writing kernel args if this group is non-empty. + if (M > 0) { + // Get the index for this group atomically. + int non_zero_idx = atomicAdd(&non_zero_counter, 1); + // We compute the offset by getting the cumulative sum over + // prior groups. + int64_t offset_M = 0; + int64_t accumulated_x_scale = 0; + int64_t accumulated_w_scale = 0; + for (int i = 0; i < group_index; i++) { + offset_M += M_sizes[i]; + /* It's calculated this way since the scales are at least padded to + multiples of (128, 4), and there is a group of 32 elements per scale. + */ + accumulated_w_scale += + (((N + 128 - 1) / 128) * 128 * ((K + 4 - 1) / 4) * 4 / 32); + } + accumulated_x_scale = starting_row_after_padding[group_index] * K / 32; + // Set the problem shape for this group. + problem_shape_ptr[non_zero_idx] = ProblemShape(N, M, K); + // Set input pointers. 
+ xq_ptr[non_zero_idx] = xq + (offset_M * K); + wq_ptr[non_zero_idx] = wq + (group_index * N * K); + x_scale_ptr[non_zero_idx] = x_scale + accumulated_x_scale; + w_scale_ptr[non_zero_idx] = w_scale + accumulated_w_scale; + output_ptr[non_zero_idx] = output + (offset_M * N); + stride_a_ptr[non_zero_idx] = cutlass::make_cute_packed_stride( + StrideA{}, cute::make_shape(int(M), int(K), 1)); + stride_b_ptr[non_zero_idx] = cutlass::make_cute_packed_stride( + StrideB{}, cute::make_shape(int(N), int(K), 1)); + stride_c_ptr[non_zero_idx] = cutlass::make_cute_packed_stride( + StrideC{}, cute::make_shape(int(N), int(M), 1)); + layout_SFA[non_zero_idx] = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( + cute::make_shape(int(M), int(N), int(K), 1)); + layout_SFB[non_zero_idx] = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( + cute::make_shape(int(M), int(N), int(K), 1)); + } + } +} + +template < + typename InputType, + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K> +at::Tensor mx8mx8bf16_grouped_impl( + InputType XQ, // FP8 + InputType WQ, // FP8 + InputType x_scale, + InputType w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding) { + // The number of groups the kernel uses may vary. + int kernel_groups = G; + + at::TensorOptions options; + options = XQ.options(); + + // Return early if there are no elements in the output. + if (output.numel() == 0) { + return output; + } + + // Define gemm configuration. + using ProblemShape = + cutlass::gemm::GroupProblemShape>; + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float8_t; + using ElementC = cutlass::bfloat16_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; + constexpr int AlignmentA = 32; + constexpr int AlignmentB = 32; + using ElementGlobalScale = float; + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm100; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using TileShape = + cute::Shape, cute::Int, cute::Int>; + using ClusterShape = + cute::Shape, cute::Int, cute::Int>; + + using KernelSchedule = cute::conditional_t< + (TB_M == 256) && (TBS_M % 2 == 0), + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100>; + using EpilogueSchedule = cute::conditional_t< + (TB_M == 256) && (TBS_M % 2 == 0), + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + // void, // Indicate there is no beta scaling to save register + // space. 
+ ElementC, + typename cutlass::layout::LayoutTranspose::type*, + 128 / cutlass::sizeof_bits::value, + ElementC, + typename cutlass::layout::LayoutTranspose::type*, + 128 / cutlass::sizeof_bits::value, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + cutlass::arch::OpClassBlockScaledTensorOp, + ElementB, + LayoutB_Transpose*, + AlignmentB, + ElementA, + LayoutA_Transpose*, + AlignmentA, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel:: + GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideD; + + using ElementComputeEpilogue = typename ElementA::ScaleFactorType; + + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop:: + InternalLayoutSFA; // Scale Factor tensors have an interleaved layout. + // Bring Layout instead of stride. + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop:: + InternalLayoutSFB; // Scale Factor tensors have an interleaved layout. + // Bring Layout instead of stride. + + // Create a buffer for kernel arguments. We do this by first figuring out + // how much space each sub-argument requires and setting up corresponding + // pointers. + const int64_t problem_size_offset = 0; + int64_t problem_size_buffer = + _byte_align(G * sizeof(ProblemShape::UnderlyingProblemShape)); + + // Next create space for XQ pointers. + const int64_t xq_offset = problem_size_offset + problem_size_buffer; + int64_t xq_size_buffer = _byte_align(G * sizeof(ElementA**)); + + // WQ Pointers. + const int64_t wq_offset = xq_offset + xq_size_buffer; + int64_t wq_size_buffer = _byte_align(G * sizeof(ElementB**)); + + // X block scales. + const int64_t x_scale_offset = wq_offset + wq_size_buffer; + int64_t x_scale_buffer = _byte_align(G * sizeof(ElementComputeEpilogue**)); + + // W block scales. + const int64_t w_scale_offset = x_scale_offset + x_scale_buffer; + int64_t w_scale_buffer = _byte_align(G * sizeof(ElementComputeEpilogue**)); + + // Outputs. + const int64_t output_offset = w_scale_offset + w_scale_buffer; + int64_t output_buffer = _byte_align(G * sizeof(ElementC**)); + + // A stride. + const int64_t stride_a_offset = output_offset + output_buffer; + int64_t stride_a_buffer = _byte_align(G * sizeof(StrideA)); + + // B stride; + const int64_t stride_b_offset = stride_a_offset + stride_a_buffer; + int64_t stride_b_buffer = _byte_align(G * sizeof(StrideB)); + + // C stride; + const int64_t stride_c_offset = stride_b_offset + stride_b_buffer; + int64_t stride_c_buffer = _byte_align(G * sizeof(StrideC)); + + // SFA layout + const int64_t layout_SFA_offset = stride_c_offset + stride_c_buffer; + int64_t layout_SFA_buffer = _byte_align(G * sizeof(LayoutSFA)); + + // SFB layout + const int64_t layout_SFB_offset = layout_SFA_offset + layout_SFA_buffer; + int64_t layout_SFB_buffer = _byte_align(G * sizeof(LayoutSFB)); + + // Compute total buffer size + int64_t total_buffer_size = layout_SFB_offset + layout_SFB_buffer; + + // Allocate space for gemm information. + at::Tensor kernel_args = + at::empty({total_buffer_size}, options.dtype(at::kByte)); + + // Get byte pointer to underlying data. 
+ char* kernel_args_ptr = reinterpret_cast(kernel_args.data_ptr()); + + // Now use offsets to get appropriately typed pointers. + ProblemShape::UnderlyingProblemShape* problem_shape_ptr = + reinterpret_cast( + kernel_args_ptr + problem_size_offset); + const ElementA** xq_ptr = + reinterpret_cast(kernel_args_ptr + xq_offset); + const ElementB** wq_ptr = + reinterpret_cast(kernel_args_ptr + wq_offset); + const ElementComputeEpilogue** x_scale_ptr = + reinterpret_cast( + kernel_args_ptr + x_scale_offset); + const ElementComputeEpilogue** w_scale_ptr = + reinterpret_cast( + kernel_args_ptr + w_scale_offset); + ElementC** output_ptr = + reinterpret_cast(kernel_args_ptr + output_offset); + StrideA* stride_a_ptr = + reinterpret_cast(kernel_args_ptr + stride_a_offset); + StrideB* stride_b_ptr = + reinterpret_cast(kernel_args_ptr + stride_b_offset); + StrideC* stride_c_ptr = + reinterpret_cast(kernel_args_ptr + stride_c_offset); + LayoutSFA* layout_SFA = + reinterpret_cast(kernel_args_ptr + layout_SFA_offset); + LayoutSFB* layout_SFB = + reinterpret_cast(kernel_args_ptr + layout_SFB_offset); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // For SFA and SFB tensors layouts + using Sm1xxBlkScaledConfig = + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + TORCH_CHECK( + !zero_start_index_M.has_value() || + zero_start_index_M->dtype() == at::kLong, + "zero_start_index_M must be int64."); + + TORCH_CHECK( + !M_sizes.has_value() || M_sizes->dtype() == at::kLong, + "M_sizes must be int64."); + // When m_offsets is used, XQ is shape [total_M, K]. When zero_start_index_M + // is used, shape is [G, M, K]. + int64_t M = XQ.size(XQ.dim() - 2); + int64_t N = WQ.size(1); + int64_t K = WQ.size(2); + + // Calculate the number of scale elements per group + int64_t num_x_scale_per_group; + int64_t num_w_scale_per_group; + TORCH_CHECK( + x_scale.dim() == 2 || x_scale.dim() == 3, + "x_scale must be either 2D or 3D tensor") + if (x_scale.dim() == 3) { + num_x_scale_per_group = x_scale.size(1) * x_scale.size(2); + } else { + num_x_scale_per_group = x_scale.size(1); + } + TORCH_CHECK( + w_scale.dim() == 2 || w_scale.dim() == 3, + "w_scale must be either 2D or 3D tensor") + if (w_scale.dim() == 3) { + num_w_scale_per_group = w_scale.size(1) * w_scale.size(2); + } else { + num_w_scale_per_group = w_scale.size(1); + } + + int64_t* M_sizes_ptr = reinterpret_cast(M_sizes.value().data_ptr()); + int64_t* starting_row_after_padding_ptr = + reinterpret_cast(starting_row_after_padding.value().data_ptr()); + set_stacked_kernel_args_kernel< + ProblemShape::UnderlyingProblemShape, + ElementA, + ElementB, + ElementC, + ElementComputeEpilogue, + StrideA, + StrideB, + StrideC, + LayoutSFA, + LayoutSFB, + ElementGlobalScale, + Sm1xxBlkScaledConfig><<<1, G, 0, stream>>>( + G, + N, + K, + num_x_scale_per_group, + num_w_scale_per_group, + problem_shape_ptr, + reinterpret_cast(XQ.data_ptr()), + xq_ptr, + reinterpret_cast(WQ.data_ptr()), + wq_ptr, + reinterpret_cast(x_scale.data_ptr()), + x_scale_ptr, + reinterpret_cast(w_scale.data_ptr()), + w_scale_ptr, + reinterpret_cast(output.data_ptr()), + output_ptr, + stride_a_ptr, + stride_b_ptr, + stride_c_ptr, + M_sizes_ptr, + layout_SFA, + layout_SFB, + starting_row_after_padding_ptr); + // Set the number of groups to the kernel to be at most the number of + // non-zero rows. 
+ kernel_groups = int(std::min(M, G)); + + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = + min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id), + 2147483647); // INT_MAX + + using DataTypeA = typename ElementA::DataType; + using DataTypeB = typename ElementB::DataType; + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {kernel_groups, problem_shape_ptr, nullptr}, + {reinterpret_cast(wq_ptr), + stride_b_ptr, + reinterpret_cast(xq_ptr), + stride_a_ptr, + reinterpret_cast(w_scale_ptr), + layout_SFB, + reinterpret_cast(x_scale_ptr), + layout_SFA}, + {{}, nullptr, stride_c_ptr, output_ptr, stride_c_ptr}, + hw_info}; + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + at::Tensor workspace = at::empty(workspace_size, options.dtype(at::kByte)); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize( + arguments, reinterpret_cast(workspace.data_ptr())); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return output; +} + +#endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh new file mode 100644 index 0000000000..e2b76186a2 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh @@ -0,0 +1,103 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace fbgemm_gpu { + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) + +at::Tensor mx8mx8bf16_grouped_128_64_256_1_1_1( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding); + +at::Tensor mx8mx8bf16_grouped_128_128_256_1_1_1( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding); + +at::Tensor mx8mx8bf16_grouped_256_64_256_2_1_1( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding); + +at::Tensor mx8mx8bf16_grouped_256_128_256_2_1_1( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding); + +at::Tensor mx8mx8bf16_grouped_256_256_256_2_1_1( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, + int64_t G, + std::optional zero_start_index_M, + std::optional M_sizes, + std::optional starting_row_after_padding); + +template +using Kernel_mx8mx8bf16_grouped = at::Tensor (*)( + InputType, + InputType, + InputType, + InputType, + at::Tensor, + int64_t, + std::optional, + std::optional, + std::optional); + +template +const std::unordered_map>& +get_mx8mx8bf16_grouped_kernels() { + static const std:: + unordered_map> + kernels = { + {"mx8mx8bf16_grouped_128_64_256_1_1_1", + mx8mx8bf16_grouped_128_64_256_1_1_1}, + {"mx8mx8bf16_grouped_128_128_256_1_1_1", + mx8mx8bf16_grouped_128_128_256_1_1_1}, + {"mx8mx8bf16_grouped_256_64_256_2_1_1", + mx8mx8bf16_grouped_256_64_256_2_1_1}, + {"mx8mx8bf16_grouped_256_128_256_2_1_1", + mx8mx8bf16_grouped_256_128_256_2_1_1}, + {"mx8mx8bf16_grouped_256_256_256_2_1_1", + mx8mx8bf16_grouped_256_256_256_2_1_1}, + }; + return kernels; +} + +#endif +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 139d05cc9c..c33136d86a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -64,6 +64,13 @@ at::Tensor f4f4bf16_grouped_stacked( std::optional global_scale = std::nullopt, std::optional starting_row_after_padding = std::nullopt, bool use_mx = true); +at::Tensor mx8mx8bf16_grouped_stacked( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor M_sizes, + std::optional starting_row_after_padding = std::nullopt); at::Tensor f8f8bf16( at::Tensor XQ, at::Tensor WQ, @@ -313,6 +320,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("f4f4bf16", f4f4bf16); m.impl("f4f4bf16_grouped", f4f4bf16_grouped); m.impl("f4f4bf16_grouped_stacked", f4f4bf16_grouped_stacked); + m.impl("mx8mx8bf16_grouped_stacked", mx8mx8bf16_grouped_stacked); m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); m.impl("bf16_fast_gemv", bf16_fast_gemv); @@ -368,6 +376,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { m.impl("f4f4bf16", f4f4bf16); m.impl("f4f4bf16_grouped", f4f4bf16_grouped); 
m.impl("f4f4bf16_grouped_stacked", f4f4bf16_grouped_stacked); + m.impl("mx8mx8bf16_grouped_stacked", mx8mx8bf16_grouped_stacked); m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); m.impl("bf16_fast_gemv", bf16_fast_gemv); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp index ae3ab5851a..dd6f949338 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp @@ -24,6 +24,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "f4f4bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[]? global_scale=None, bool use_mx=True) -> Tensor[]"); m.def( "f4f4bf16_grouped_stacked(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor M_sizes, Tensor? global_scale=None, Tensor? starting_row_after_padding=None, bool use_mx=True) -> Tensor"); + m.def( + "mx8mx8bf16_grouped_stacked(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor M_sizes, Tensor? starting_row_after_padding=None) -> Tensor"); m.def( "f8f8bf16(Tensor XQ, Tensor WQ, Tensor scale, bool use_fast_accum=True) -> Tensor"); m.def( diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index d3d73fe0d1..58773825d7 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -24,9 +24,14 @@ quantize_fp8_block, quantize_fp8_row, supports_float8_fnuz, + to_mxfp8, ) + from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle + if torch.cuda.get_device_capability() >= (10, 0): + from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked + from hypothesis import given, settings, strategies as st # Marlin is currently only supported internally at Meta. 
@@ -52,8 +57,16 @@ def evaluate_platform_supports_fp8(): return False +def evaluate_platform_supports_mxfp8(): + if torch.cuda.is_available(): + return torch.cuda.get_device_capability() >= (10, 0) + return False + + SUPPORTS_FP8 = evaluate_platform_supports_fp8() +SUPPORTS_MXFP8 = evaluate_platform_supports_mxfp8() + if torch.cuda.is_available() and supports_float8_fnuz( throw_on_hip_incompatibility=(not running_on_github) ): @@ -1214,6 +1227,80 @@ def test_grouped_gemm_2d_3d( # BF16 loopover gemm reference self.bf16_loopover_validate(x_group, W, y_fp8_group, y_bf16_group) + @unittest.skipIf(not SUPPORTS_MXFP8, "MXFP8 not supported on this platform") + @settings(deadline=None) + @given( + G=st.sampled_from([1, 4, 16]), + M=st.sampled_from([2048, 3584]), + N=st.sampled_from([256, 1024, 6144]), + K=st.sampled_from([256, 512, 3584]), + ) + def test_mx_grouped_gemm( + self, + G: int, + M: int, + N: int, + K: int, + ) -> None: + X = torch.randn((G, M, K), dtype=torch.bfloat16, device=self.device) * 0.1 + W = torch.randn((G, N, K), dtype=torch.bfloat16, device=self.device) * 0.01 + + m_values = [i.shape[0] for i in X] + m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=X[0].device) + + wq_list = [] + w_scale_list = [] + for i in range(m_sizes.shape[0]): + w_scale, wq = to_mxfp8(W[i]) + w_scale = _to_blocked(w_scale) + wq_list.append(wq) + w_scale_list.append(w_scale) + wq = torch.stack(wq_list, dim=0).contiguous() + w_scale = torch.stack(w_scale_list, dim=0).contiguous() + + starting_row_after_padding_list = [0] + xq_list = [] + x_scale_list = [] + for i in range(m_sizes.shape[0]): + scale_slice = X[i] + if m_sizes[i].item() != 0: + x_scale, xq = to_mxfp8(scale_slice) + x_scale = _to_blocked(x_scale) + xq_list.append(xq) + x_scale_list.append(x_scale) + starting_row_after_padding_list.append( + starting_row_after_padding_list[i] + + x_scale.numel() // (X[0].shape[1] // 32) + ) + else: + starting_row_after_padding_list.append( + starting_row_after_padding_list[i] + ) + starting_row_after_padding = torch.tensor( + starting_row_after_padding_list, device=xq.device + ) + + xq = torch.cat(xq_list, dim=0).contiguous() + x_scale = torch.cat(x_scale_list, dim=0).contiguous() + x_scale = x_scale.reshape(-1, X[0].shape[-1] // 32) + xq = xq.view(-1, xq.shape[-1]) + + y_mxfp8 = torch.ops.fbgemm.mx8mx8bf16_grouped_stacked( + xq, + wq, + x_scale, + w_scale, + m_sizes, + starting_row_after_padding=starting_row_after_padding, + ) + + y_bf16_group = [] + for i in range(G): + y_bf16_group.append(torch.matmul(X[i], W[i].t())) + y_bf16 = torch.cat(y_bf16_group, dim=0) + + torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2) + @unittest.skipIf( not torch.version.hip, "Only AMD supports torch 3D-2D grouped gemm API",