|
8 | 8 | import triton |
9 | 9 | import triton.language as tl |
10 | 10 |
|
| 11 | +import vllm.envs as envs |
11 | 12 | from vllm import _custom_ops as ops |
12 | 13 | from vllm.logger import init_logger |
13 | 14 |
|
@@ -420,13 +421,12 @@ def fused_experts(hidden_states: torch.Tensor, |
420 | 421 | torch.float32, torch.float16, torch.bfloat16 |
421 | 422 | ] |
422 | 423 |
|
423 | | - M, _ = hidden_states.shape |
| 424 | + num_tokens, _ = hidden_states.shape |
424 | 425 | E, N, _ = w1.shape |
425 | | - |
426 | | - if M > 65536: |
427 | | - # https://github.com/vllm-project/vllm/issues/5938 |
428 | | - raise ValueError("MoE kernel does not support more than 65536 tokens, " |
429 | | - f"but got {M}") |
| 426 | + # We execute the fused_moe kernel in chunks to circumvent this issue: |
| 427 | + # https://github.com/vllm-project/vllm/issues/5938 |
| 428 | + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE |
| 429 | + M = min(num_tokens, CHUNK_SIZE) |
430 | 430 |
|
431 | 431 | if override_config: |
432 | 432 | config = override_config |
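
The first hunk above drops the hard 65536-token limit and instead derives M from min(num_tokens, CHUNK_SIZE), so the intermediate caches (visible as context at the top of the next hunk) are sized per chunk rather than per batch. A minimal sketch of that sizing effect, using placeholder shapes and an assumed default for VLLM_FUSED_MOE_CHUNK_SIZE (the real default is whatever vllm/envs.py defines):

    import torch

    # Placeholder shapes for illustration only; inside fused_experts() these
    # come from hidden_states, topk_ids, and w1.
    num_tokens, topk, N = 100_000, 2, 256
    CHUNK_SIZE = 32_768  # assumed default of envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # Workspace rows are now bounded by the chunk size, not the batch size.
    M = min(num_tokens, CHUNK_SIZE)
    intermediate_cache1 = torch.empty((M, topk, N), dtype=torch.float16)
    assert intermediate_cache1.shape[0] == CHUNK_SIZE  # independent of num_tokens
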
@@ -455,51 +455,74 @@ def fused_experts(hidden_states: torch.Tensor, |
455 | 455 | device=hidden_states.device, |
456 | 456 | dtype=hidden_states.dtype) |
457 | 457 |
|
458 | | - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( |
459 | | - topk_ids, config['BLOCK_SIZE_M'], E) |
460 | 458 | compute_type = (tl.bfloat16 |
461 | 459 | if hidden_states.dtype == torch.bfloat16 else tl.float16) |
462 | 460 |
|
463 | | - invoke_fused_moe_kernel(hidden_states, |
464 | | - w1, |
465 | | - intermediate_cache1, |
466 | | - a1_scale, |
467 | | - w1_scale, |
468 | | - topk_weights, |
469 | | - topk_ids, |
470 | | - sorted_token_ids, |
471 | | - expert_ids, |
472 | | - num_tokens_post_padded, |
473 | | - False, |
474 | | - topk_ids.shape[1], |
475 | | - config, |
476 | | - compute_type=compute_type, |
477 | | - use_fp8=use_fp8) |
478 | | - |
479 | | - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) |
480 | | - |
481 | | - invoke_fused_moe_kernel(intermediate_cache2, |
482 | | - w2, |
483 | | - intermediate_cache3, |
484 | | - a2_scale, |
485 | | - w2_scale, |
486 | | - topk_weights, |
487 | | - topk_ids, |
488 | | - sorted_token_ids, |
489 | | - expert_ids, |
490 | | - num_tokens_post_padded, |
491 | | - True, |
492 | | - 1, |
493 | | - config, |
494 | | - compute_type=compute_type, |
495 | | - use_fp8=use_fp8) |
496 | | - |
497 | 461 | if inplace: |
498 | | - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), |
499 | | - dim=1, |
500 | | - out=hidden_states) |
501 | | - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), |
502 | | - dim=1) |
| 462 | + out_hidden_states = hidden_states |
| 463 | + else: |
| 464 | + out_hidden_states = torch.empty_like(hidden_states) |
| 465 | + |
| 466 | + for chunk in range((num_tokens // CHUNK_SIZE) + 1): |
| 467 | + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, |
| 468 | + min((chunk + 1) * CHUNK_SIZE, |
| 469 | + num_tokens)) |
| 470 | + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] |
| 471 | + tokens_in_chunk, _ = curr_hidden_states.shape |
| 472 | + |
| 473 | + if tokens_in_chunk == 0: |
| 474 | + break |
| 475 | + |
| 476 | + if tokens_in_chunk < CHUNK_SIZE: |
| 477 | + # will only happen in the last chunk |
| 478 | + intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] |
| 479 | + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] |
| 480 | + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] |
| 481 | + |
| 482 | + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] |
| 483 | + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] |
| 484 | + |
| 485 | + sorted_token_ids, expert_ids, num_tokens_post_padded = ( |
| 486 | + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) |
| 487 | + |
| 488 | + invoke_fused_moe_kernel(curr_hidden_states, |
| 489 | + w1, |
| 490 | + intermediate_cache1, |
| 491 | + a1_scale, |
| 492 | + w1_scale, |
| 493 | + curr_topk_weights, |
| 494 | + curr_topk_ids, |
| 495 | + sorted_token_ids, |
| 496 | + expert_ids, |
| 497 | + num_tokens_post_padded, |
| 498 | + False, |
| 499 | + topk_ids.shape[1], |
| 500 | + config, |
| 501 | + compute_type=compute_type, |
| 502 | + use_fp8=use_fp8) |
| 503 | + |
| 504 | + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) |
| 505 | + |
| 506 | + invoke_fused_moe_kernel(intermediate_cache2, |
| 507 | + w2, |
| 508 | + intermediate_cache3, |
| 509 | + a2_scale, |
| 510 | + w2_scale, |
| 511 | + curr_topk_weights, |
| 512 | + curr_topk_ids, |
| 513 | + sorted_token_ids, |
| 514 | + expert_ids, |
| 515 | + num_tokens_post_padded, |
| 516 | + True, |
| 517 | + 1, |
| 518 | + config, |
| 519 | + compute_type=compute_type, |
| 520 | + use_fp8=use_fp8) |
| 521 | + |
| 522 | + torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), |
| 523 | + dim=1, |
| 524 | + out=out_hidden_states[begin_chunk_idx:end_chunk_idx]) |
| 525 | + return out_hidden_states |
503 | 526 |
|
504 | 527 |
|
505 | 528 | def fused_moe( |
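
Stripped of the kernel launches, the control flow added by the second hunk reduces to the chunked-loop pattern sketched below. This is a simplified, stand-alone sketch rather than vLLM code: process_chunk() is a hypothetical placeholder for the two invoke_fused_moe_kernel launches joined by ops.silu_and_mul, while the index arithmetic, the empty-chunk break, the inplace handling, and the torch.sum(..., out=...) reduction mirror the diff.

    import torch

    def process_chunk(curr_hidden: torch.Tensor,
                      curr_weights: torch.Tensor,
                      curr_ids: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for the real per-chunk work (two fused MoE
        # GEMMs joined by silu_and_mul). Returns one row per
        # (token, selected expert) pair: shape (tokens, topk, hidden).
        tokens, hidden = curr_hidden.shape
        topk = curr_ids.shape[1]
        return curr_hidden.unsqueeze(1).expand(tokens, topk, hidden) * curr_weights.unsqueeze(-1)

    def chunked_moe(hidden_states: torch.Tensor,
                    topk_weights: torch.Tensor,
                    topk_ids: torch.Tensor,
                    chunk_size: int,
                    inplace: bool = False) -> torch.Tensor:
        num_tokens, _ = hidden_states.shape
        out = hidden_states if inplace else torch.empty_like(hidden_states)

        # num_tokens // chunk_size + 1 iterations cover every token; the last
        # iteration is either a partial chunk or empty (handled by the break).
        for chunk in range(num_tokens // chunk_size + 1):
            begin, end = chunk * chunk_size, min((chunk + 1) * chunk_size, num_tokens)
            curr_hidden = hidden_states[begin:end]
            if curr_hidden.shape[0] == 0:
                break  # num_tokens was an exact multiple of chunk_size

            expanded = process_chunk(curr_hidden,
                                     topk_weights[begin:end],
                                     topk_ids[begin:end])

            # Reduce over the top-k dimension directly into the output slice,
            # as the diff does with torch.sum(..., out=out_hidden_states[...]).
            torch.sum(expanded, dim=1, out=out[begin:end])
        return out

    # Usage sketch: 5000 tokens, hidden size 64, top-2 routing, chunk size 2048.
    x = torch.randn(5000, 64)
    y = chunked_moe(x,
                    topk_weights=torch.rand(5000, 2),
                    topk_ids=torch.randint(0, 8, (5000, 2)),
                    chunk_size=2048)
    assert y.shape == x.shape

Reusing the same intermediate caches across chunks (and slicing them for the final partial chunk, as the diff does) keeps peak memory bounded by the chunk size instead of the token count, which is the point of the change.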