11 changes: 3 additions & 8 deletions benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -19,7 +19,7 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul,
 )
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, cdiv

 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
 DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
@@ -117,14 +117,9 @@ def bench_fp8(
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

-    def ceil_div(x: int, y: int) -> int:
-        return (x + y - 1) // y
-
-    block_scale_a = torch.rand(
-        (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
-    )
+    block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32)
     block_scale_b = torch.rand(
-        ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
+        cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32
     )
     block_scale_a_M_major = block_scale_a.t().contiguous().t()
     block_scale_b_K_major = block_scale_b.t().contiguous().t()
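Note: every local `ceil_div` removed in this PR computed `(a + b - 1) // b`, so the shared `vllm.utils.cdiv` is a drop-in replacement for that behavior on positive integers. A minimal sketch of the equivalence (the helper below is illustrative, not the actual `vllm.utils` implementation):

```python
# Illustrative ceiling-division helper matching the removed local ceil_div;
# the real vllm.utils.cdiv may be implemented differently.
def cdiv(a: int, b: int) -> int:
    return -(a // -b)  # equals (a + b - 1) // b for positive b

assert cdiv(256, 128) == 2  # exact multiple of the block size
assert cdiv(257, 128) == 3  # rounds up so the scale tensor covers the tail
assert cdiv(1, 128) == 1
```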
5 changes: 1 addition & 4 deletions tests/kernels/attention/test_mla_decode_cpu.py
@@ -7,10 +7,7 @@

 import vllm._custom_ops as ops
 from vllm.platforms import current_platform
-
-
-def cdiv(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv


 def ref_mla(
5 changes: 1 addition & 4 deletions tests/kernels/attention/test_triton_decode_attention.py
@@ -5,10 +5,7 @@
 import torch

 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
-
-
-def cdiv(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv


 @pytest.mark.parametrize("B", [3, 5])
9 changes: 4 additions & 5 deletions tests/neuron/1_core/test_prefix_prefill.py
@@ -7,6 +7,8 @@
 import torch
 import torch.nn.functional as F

+from vllm.utils import cdiv
+

 class BlockDiagonalCausalFromBottomRightMask:

@@ -398,11 +400,8 @@ def test_contexted_kv_attention(
     assert (large_tile_size >= B_P_SIZE
             ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"

-    def ceil_div(a, b):
-        return (a + b - 1) // b
-
     def pad_to_multiple(a, b):
-        return ceil_div(a, b) * b
+        return cdiv(a, b) * b

     def pad_to_next_power_of_2(a):
         assert a > 0
@@ -411,7 +410,7 @@ def pad_to_next_power_of_2(a):
     # calculate input shapes
     max_num_queries = pad_to_next_power_of_2(sum(query_lens))
     context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
-    num_active_blocks = ceil_div(context_lens, block_size).sum().item()
+    num_active_blocks = cdiv(context_lens, block_size).sum().item()
     num_active_blocks = pad_to_multiple(num_active_blocks,
                                         large_tile_size // block_size)
     context_kv_len = num_active_blocks * block_size
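Note: in this test `cdiv` is applied to a whole tensor (`context_lens`), not just Python ints, which works because the ceiling-division formula broadcasts elementwise. A small sketch of the padding arithmetic above, assuming the elementwise behavior of the removed helper:

```python
import torch

def cdiv(a, b):
    # assumed elementwise ceiling division, as in the removed local helper
    return (a + b - 1) // b

def pad_to_multiple(a, b):
    return cdiv(a, b) * b

context_lens = torch.tensor([5, 12, 0])  # hypothetical per-sequence context lengths
block_size = 4
num_active_blocks = cdiv(context_lens, block_size).sum().item()  # 2 + 3 + 0 = 5
num_active_blocks = pad_to_multiple(num_active_blocks, 8)        # padded up to 8
```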
15 changes: 6 additions & 9 deletions vllm/attention/ops/nki_flash_attn.py
@@ -8,9 +8,7 @@
 from neuronxcc import nki
 from neuronxcc.nki.language import par_dim

-
-def ceil_div(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv


 def is_power_of_2(x):
@@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
        (num_tiles, num_blocks_per_tile))

     block_tables_sbuf = nl.zeros(
-        (ceil_div(num_tiles,
-                  B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
+        (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
         dtype=nl.int32,
     )
-    for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
+    for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
         i_p = nl.arange(B_P_SIZE)[:, None]
         i_f = nl.arange(num_blocks_per_tile)[None, :]
         block_tables_sbuf[i, i_p, i_f] = nl.load(
@@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load(
     assert is_power_of_2(
         num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"

-    num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
+    num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
     block_tables_transposed = nl.ndarray(
         (
             num_loads,
@@ -165,7 +162,7 @@ def load_kv_tile_from_cache(
     equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
     """
     # load key cache
-    num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
+    num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
     for load_idx in nl.affine_range(num_loads):
         i_p = nl.arange(B_P_SIZE)[:, None]
         i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
@@ -605,7 +602,7 @@ def flash_paged_attention(
     )

     for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
-        num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
+        num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
         cur_k_tile = nl.ndarray(
             (par_dim(B_D_SIZE), LARGE_TILE_SZ),
             dtype=kernel_dtype,
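Note: throughout this kernel, `cdiv(n, B_P_SIZE)` is the number of partition-sized loads needed to cover `n` rows. A host-side torch analogue of the slab layout used by `load_block_tables` (the sizes are hypothetical and this does not model NKI's on-chip semantics):

```python
import torch

B_P_SIZE = 128  # partition size, as in the kernel

def cdiv(a, b):
    return (a + b - 1) // b

num_tiles, num_blocks_per_tile = 300, 16  # hypothetical sizes
block_tables = torch.arange(num_tiles * num_blocks_per_tile).reshape(
    num_tiles, num_blocks_per_tile)

# one zero-initialized slab per cdiv(num_tiles, B_P_SIZE), filled row-wise;
# the last slab is only partially populated when num_tiles % B_P_SIZE != 0
num_slabs = cdiv(num_tiles, B_P_SIZE)
sbuf_like = torch.zeros(num_slabs, B_P_SIZE, num_blocks_per_tile,
                        dtype=block_tables.dtype)
for i in range(num_slabs):
    rows = block_tables[i * B_P_SIZE:(i + 1) * B_P_SIZE]
    sbuf_like[i, :rows.shape[0]] = rows

assert sbuf_like.shape == (3, 128, 16)  # 300 tiles -> 3 slabs of 128
```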
@@ -6,11 +6,7 @@

 from vllm import _custom_ops as ops
 from vllm.triton_utils import tl, triton
-from vllm.utils import round_up
-
-
-def ceil_div(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv, round_up


 @triton.jit
@@ -115,7 +111,7 @@ def moe_align_block_size_triton(
     cumsum = torch.zeros((num_experts + 1, ),
                          dtype=torch.int32,
                          device=topk_ids.device)
-    tokens_per_thread = ceil_div(numel, num_experts)
+    tokens_per_thread = cdiv(numel, num_experts)

     moe_align_block_size_stage1[grid](
         topk_ids,
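Note: `tokens_per_thread = cdiv(numel, num_experts)` gives each stage-1 program an equally sized slice of the flattened `topk_ids`, with only the last slice running short. A pure-Python sketch of that partitioning (names are illustrative):

```python
def cdiv(a, b):
    return (a + b - 1) // b

def partition(numel, num_workers):
    # each worker scans at most cdiv(numel, num_workers) elements;
    # only the final worker's slice can be shorter
    per_worker = cdiv(numel, num_workers)
    return [(w * per_worker, min((w + 1) * per_worker, numel))
            for w in range(num_workers)]

# e.g. 10 token-expert pairs over 4 workers -> slices of 3, 3, 3, 1
assert partition(10, 4) == [(0, 3), (3, 6), (6, 9), (9, 10)]
```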
9 changes: 3 additions & 6 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -19,7 +19,7 @@
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import cdiv, direct_register_custom_op

 logger = init_logger(__name__)
 has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@@ -158,12 +158,9 @@ def apply_w8a8_block_fp8_linear(
     if current_platform.is_cuda():
         if current_platform.has_device_capability(100):

-            def ceil_div(x: int, y: int) -> int:
-                return (x + y - 1) // y
-
             use_cutlass = cutlass_block_fp8_supported and (
-                ceil_div(weight.shape[0], 128) == weight_scale.shape[0]
-                and ceil_div(weight.shape[1], 128) == weight_scale.shape[1])
+                cdiv(weight.shape[0], 128) == weight_scale.shape[0]
+                and cdiv(weight.shape[1], 128) == weight_scale.shape[1])
         else:
             # TODO: update this after switching to public sm90 block scale gemm
             # as it also supports weight.shape % 128 != 0
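Note: the `use_cutlass` condition verifies that the scale tensor has exactly one entry per 128×128 weight block, i.e. a `(cdiv(N, 128), cdiv(K, 128))` scale for an `(N, K)` weight. A small sketch of that shape check with hypothetical dimensions:

```python
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

# hypothetical block-quantized layer: weight is (N, K) = (1536, 4608),
# with one float32 scale per 128 x 128 block
weight_shape = (1536, 4608)
weight_scale_shape = (cdiv(weight_shape[0], 128), cdiv(weight_shape[1], 128))  # (12, 36)

shapes_match = (cdiv(weight_shape[0], 128) == weight_scale_shape[0]
                and cdiv(weight_shape[1], 128) == weight_scale_shape[1])
assert shapes_match
```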