
Conversation

@Mhmd-Hisham (Contributor) commented Sep 4, 2025

This PR improves the performance of NF4 and FP4 dequantization in the kDequantizeBlockwise CUDA kernel by replacing conditional branches with lookup tables.

The current use of if conditions in the implementations of dDequantizeNF4 and dDequantizeFP4Tree introduces branching and reduces occupancy in CUDA kernels.

To address this, we can treat the final quantized values as indices and use a lookup table for dequantization. However, this requires the quantization functions (dQuantizeFP4 and dQuantizeNF4) to map values in strictly increasing order.

This is already the case for NF4. For FP4 (dQuantizeFP4), however, the quantized values were out of order, and this PR reorders them in dQuantizeFP4. Note that sorting is not actually needed for dequantization; it was only required for a branchless quantization kernel that I am working on separately.
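
To make the idea concrete, here is a minimal PyTorch sketch of the lookup-table approach, assuming the usual packing of two 4-bit codes per byte (first value in the high nibble). The actual change lives in the CUDA device functions; the helper name is hypothetical and the NF4 constants below are truncated, so treat this as an illustration rather than the kernel code.

```python
import torch

# Truncated NF4 code values, for illustration only; the kernel uses exact constants.
NF4_CODE = torch.tensor([
    -1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0000,
     0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5626,  0.7230, 1.0000,
])

def dequantize_nf4_lut(packed: torch.Tensor, absmax: torch.Tensor, blocksize: int = 64) -> torch.Tensor:
    """Treat each 4-bit code as an index into a 16-entry table: one gather, no branches."""
    high = (packed >> 4) & 0x0F                       # first nibble of each byte
    low = packed & 0x0F                               # second nibble
    codes = torch.stack((high, low), dim=-1).reshape(-1).long()
    values = NF4_CODE.to(packed.device)[codes]        # LUT gather replaces the if/else chain
    scales = absmax.repeat_interleave(blocksize).to(values.dtype)
    return values * scales                            # per-block absmax scaling
```

The branchy version walks a nested if/else tree on each nibble; indexing a table turns that into a single load, which is what the kernel change does on-device.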

Benchmarking methodology

  • Stress test: I built a custom stress test calling bitsandbytes.functional.dequantize_4bit with different input sizes, 1000 iterations each, and measured latency with torch.cuda.Event (a sketch of this timing pattern is shown below).
  • Kernel profiling: NVIDIA Nsight Compute (NCU) on the custom stress test with 100 iterations, to isolate the kDequantizeBlockwise kernel from background noise.
  • Inference benchmark: a modified inference_benchmark.py, adjusted for stable results.

Benchmarks were run on an H100 80GB SXM GPU (from Nebius.AI) with warmup runs, fixed GPU clocks, each run executed in an isolated Docker container, and multiple repetitions (t-test verified).
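
For reference, here is a rough sketch of the stress-test timing loop, assuming bitsandbytes' quantize_4bit/dequantize_4bit API. The shapes, iteration counts, and warmup handling below are illustrative, not the exact scripts used for the reported numbers.

```python
import torch
import bitsandbytes.functional as F

def time_dequantize(n_elements: int, iters: int = 1000, quant_type: str = "nf4") -> float:
    """Return mean dequantize_4bit latency in milliseconds, measured with CUDA events."""
    x = torch.randn(n_elements, device="cuda", dtype=torch.float16)
    packed, state = F.quantize_4bit(x, quant_type=quant_type)

    for _ in range(10):                      # warmup
        F.dequantize_4bit(packed, state)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        F.dequantize_4bit(packed, state)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

if __name__ == "__main__":
    for n in (1 << 20, 1 << 24, 1 << 26):
        print(f"{n:>10d} elements: {time_dequantize(n):.4f} ms")
```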

Results

[Figure: H100_Inference_Benchmark]

Optimum Benchmark (Llama-3.1-70B inference on an H100):

  • For decode with batch sizes greater than 1, throughput improves by ~33% and latency drops by ~25%.
  • For prefill with batch size 1, throughput improves by ~12% and latency drops by ~10%. The effect diminishes at larger batch sizes, so further investigation may be needed here.
[Figure: stress_test_plot]

Stress Test results:

  • Stress test (torch.cuda.Event): ~30% latency reduction.
  • Kernel-level speedup (NCU on stress test): ~40% latency reduction.

List of changes

  • Improved decode throughput by ~33% and reduced latency by ~25% for batch sizes > 1 (Llama-3.1-70B on an H100 80GB SXM).
  • Improved prefill throughput by ~12% and reduced latency by ~10% for batch size 1.
  • Reduced dequantization latency for NF4 by ~40% overall.

The complete analysis notebook, stress test scripts, and Docker/environment setup can be found here.

@Mhmd-Hisham Mhmd-Hisham changed the title [CUDA] Branchless NF4/FP4 dequantization for faster kDequantizeBlockwise [CUDA] Branchless NF4/FP4 kDequantizeBlockwise kernel for faster dequantization Sep 5, 2025
@matthewdouglas (Member) commented

Hi @Mhmd-Hisham!

I really appreciate the effort and analysis here! The timing is impeccable too, as I was also working on improving this kernel for the same reasons this week! Please see the PR that I've opened: #1747.

I still need to do the end-to-end benchmarking, but I'd like to follow the same methodology that you're using here so we can compare directly.

One of the reasons why I would prefer a change more like mine is that we don't need to reassign the values for the FP4 LUT. Doing so breaks backwards compatibility with existing checkpoints that have been published.

@matthewdouglas matthewdouglas added the CUDA Issues and PRs related to the CUDA backend, excluding installation/support help. label Sep 5, 2025
@Mhmd-Hisham (Contributor, Author) commented Sep 5, 2025

Hi @matthewdouglas, thanks for the quick and thorough feedback!

I’d be glad to collaborate on this. To ensure consistent comparisons, you could use my complete benchmarking setup linked here. Alternatively, if you can grant me access to your setup, I’d be happy to run the benchmarks myself and compare both branches directly.

Regarding the backwards-compatibility concern with FP4: I was reordering the FP4 values because I'm working on a branchless quantization kernel that requires them to be ordered. I didn't include that work here since I'm still developing a benchmarking script for model load + quantization (excluding disk I/O). However, I can work around the ordering for quantization and keep the FP4 order as-is if reordering breaks backwards compatibility. We can discuss the branchless quantization work separately in another PR to keep things clear, but please let me know if this is something you are working on so we can collaborate on it as well.

For dequantization, there is no need to sort FP4 values. Apologies for the confusion.

@Mhmd-Hisham force-pushed the cuda-branchless-dequantization-float32-lut branch from 4e2ef54 to 2c4927d on September 6, 2025 at 02:34
@github-actions (bot) commented Sep 8, 2025

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Mhmd-Hisham (Contributor, Author) commented

Hi @matthewdouglas,

I went ahead and benchmarked your implementation from #1747 on Nebius.AI, using the same setup as this PR.

Here are the results compared side by side:

[Figure: side_by_side_comparison]

Note: cuda-branchless-dequantization-shfl-sync is #1747.

Overall, my implementation is slightly faster: ~2-3% for decode latency, ~4% for decode throughput, and ~1% for prefill (both throughput and latency). You can check the complete analysis for both branches here.

@matthewdouglas (Member) commented

@Mhmd-Hisham This looks great, well done! I'm going to run some more tests on this on a few additional GPUs (T4, A100, L40S, RTX 4090, and possibly B200). I'll get back to you next week on this, but I think we'll probably end up merging this and closing #1747!

@matthewdouglas matthewdouglas added this to the v0.48.0 milestone Sep 15, 2025
@matthewdouglas (Member) commented Sep 18, 2025

After testing we've found that these improvements hold up on A100, H100, and B200. I've additionally done some testing on T4, L40S, and A10G. Like #1747, the improvement is less pronounced on those GPUs, especially as the problem sizes grow.

I'm merging this one and closing the other PR.

Thanks @Mhmd-Hisham!

@matthewdouglas merged commit b1f80b8 into bitsandbytes-foundation:main on Sep 18, 2025 (47 checks passed)