
Conversation

@cdutr commented Nov 14, 2025

What does this PR do?

Fixes the QwenImage encoder to properly apply encoder_hidden_states_mask when passed to the model. Previously, the mask parameter was accepted but ignored, causing padding tokens to incorrectly influence attention computation.

Changes

  • Attention mask application: Modified QwenDoubleStreamAttnProcessor2_0 to create a 2D attention mask from the 1D encoder_hidden_states_mask, masking text padding tokens while keeping all image tokens unmasked (see the sketch after this list)
  • RoPE adjustment: Updated positional embedding computation to use the full padded sequence length when a mask is present, ensuring correct position indices
  • Tests: Added comprehensive tests validating that:
    • Padding tokens are properly isolated and don't affect outputs
    • Masked outputs differ significantly from unmasked outputs
  • Benchmarks: Included performance analysis showing acceptable overhead (<20% for inference, ~19% for training scenarios)
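
A minimal sketch of the masking idea referenced above (illustrative names and sequence order, not the exact diffusers code): the 1D text mask is extended with an all-ones block for the image tokens, then broadcast into the shape expected by scaled dot-product attention.

```python
import torch

def build_joint_attention_mask(encoder_hidden_states_mask: torch.Tensor, img_len: int) -> torch.Tensor:
    """Build a broadcastable attention mask for the joint [text, image] sequence (sketch only)."""
    # encoder_hidden_states_mask: (batch, txt_len), 1 for real text tokens, 0 for padding.
    batch, txt_len = encoder_hidden_states_mask.shape
    # Image tokens are never padded, so they are always attendable.
    img_mask = encoder_hidden_states_mask.new_ones(batch, img_len)
    # Concatenate along the sequence dimension; the [text, image] order here is illustrative.
    joint_mask = torch.cat([encoder_hidden_states_mask, img_mask], dim=1).bool()
    # Shape (batch, 1, 1, txt_len + img_len): every query may attend only to non-padding keys.
    return joint_mask[:, None, None, :]

# The result can be passed as attn_mask to F.scaled_dot_product_attention(q, k, v, attn_mask=...).
```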

Impact

This fix enables proper Classifier-Free Guidance (CFG) batching with variable-length text sequences, which is common when batching conditional and unconditional prompts together.
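
For illustration only (hypothetical shapes, not actual pipeline code): when a conditional prompt and an empty unconditional prompt are batched for CFG, both are padded to the same length, and the mask marks which positions are real.

```python
import torch

# Conditional prompt has 7 real tokens, the unconditional (empty) prompt has 1;
# both are padded to txt_len = 7 so they can share one batched forward pass.
encoder_hidden_states_mask = torch.tensor([
    [1, 1, 1, 1, 1, 1, 1],  # conditional prompt
    [1, 0, 0, 0, 0, 0, 0],  # unconditional prompt, 6 padding positions
])
# Without the fix, the 6 padding tokens in the unconditional row still receive
# attention weight and shift its output.
```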

Benchmark Results

| Scenario | Latency (ms) | Peak Memory (MB) | Throughput (iter/s) |
| --- | --- | --- | --- |
| Baseline (no mask) | 11.68 ± 0.23 | 301.5 | 84.70 |
| Mask all-ones (no padding) | 12.01 ± 0.26 | 301.5 | 82.34 |
| Mask with padding (CFG) | 13.86 ± 0.24 | 301.5 | 71.42 |

Overhead: +2.8% for mask processing without padding, +18.7% with actual padding (realistic CFG scenario)

The higher overhead with padding is expected and acceptable: it is the cost of properly handling variable-length sequences in batched inference, and this is a correctness fix rather than an optimization. Benchmarks were run on an RTX 4070 (12 GB).

Fixes #12294


Before submitting

  • This PR fixes a bug in the code
  • This PR adds tests that verify the fix
  • This PR includes benchmarks demonstrating performance impact
  • Did you write any new necessary tests?

Who can review?

@yiyixuxu @sayakpaul - Would appreciate your review, especially regarding the benchmarking approach. I used a custom benchmark rather than BenchmarkMixin because:

  1. This tests a specific bug fix (mask application) rather than optimization strategies
  2. The fix uses synthetic models to isolate the mask handling logic
  3. Standard benchmarks focus on pretrained model performance with different quantization/offloading strategies
  4. The metrics needed here (latency distribution, throughput) differ from the standard format's compile vs. plain time comparison

Note: The benchmark file is named benchmarking_qwenimage_mask.py (with "benchmarking" prefix) rather than benchmark_qwenimage_mask.py to prevent it from being picked up by run_all.py, since it doesn't use BenchmarkMixin and produces a different CSV schema. If you prefer, I can adapt it to use the standard format instead.
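
For context, a minimal sketch of the kind of timing loop behind numbers like these (illustrative only, not the actual benchmarking_qwenimage_mask.py):

```python
import time
import torch

@torch.no_grad()
def benchmark(model, inputs, warmup: int = 10, iters: int = 100):
    # Warm up to exclude CUDA context setup and kernel compilation from the timings.
    for _ in range(warmup):
        model(**inputs)
    torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        model(**inputs)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    latency_ms = torch.tensor(times) * 1000.0
    # Mean latency (ms), std (ms), and throughput (iter/s).
    return latency_ms.mean().item(), latency_ms.std().item(), 1000.0 / latency_ms.mean().item()
```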

Happy to adjust the approach if you have suggestions!

Improves attention mask handling for the QwenImage transformer by:
- Adding support for variable-length sequence masking
- Implementing dynamic attention mask generation from encoder_hidden_states_mask
- Ensuring RoPE embedding works correctly with padded sequences
- Adding comprehensive test coverage for masked input scenarios

Performance and flexibility benefits:
- Enables more efficient processing of sequences with padding
- Prevents padding tokens from contributing to attention computations
- Maintains model performance with minimal overhead

Improves file naming convention for the Qwen image mask performance benchmark script

Enhances code organization by using a more descriptive and consistent filename that clearly indicates the script's purpose
@sayakpaul (Member) commented:

@cdutr it's great that you've also included the benchmarking script for full transparency, but we can remove it from this PR and instead have it as a GitHub gist.

The benchmark numbers make sense to me. Some comments:

  • Could we also check the performance with torch.compile?
  • Could we also see some image outputs with and without the changes introduced in this PR?

Also, I think a natural next step would be to see how well this performs when combined with FA varlen. WDYT?

@naykun what do you think about the changes?

sayakpaul requested a review from yiyixuxu on November 20, 2025 at 03:59
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
