diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 115b92539f96..b4b91eda2844 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -17,13 +17,8 @@ from utils import ArgPool, Bench, CudaGraphBenchParams from weight_shapes import WEIGHT_SHAPES -from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT -from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) @@ -167,69 +162,25 @@ class OpType(Enum): """ LoRA Ops to benchmark and its properties. """ - SGMV_SHRINK = auto() - BGMV_SHRINK = auto() - SGMV_EXPAND = auto() - BGMV_EXPAND = auto() - BGMV_EXPAND_SLICE = auto() - V1_SHRINK = auto() - V1_EXPAND = auto() + LORA_SHRINK = auto() + LORA_EXPAND = auto() @staticmethod def from_str(s: str) -> "OpType": - if s.lower() == 'sgmv_shrink': - return OpType.SGMV_SHRINK - if s.lower() == 'sgmv_expand': - return OpType.SGMV_EXPAND - if s.lower() == 'bgmv_shrink': - return OpType.BGMV_SHRINK - if s.lower() == 'bgmv_expand': - return OpType.BGMV_EXPAND - if s.lower() == "bgmv_expand_slice": - return OpType.BGMV_EXPAND_SLICE - if s.lower() == "v1_shrink": - return OpType.V1_SHRINK - if s.lower() == "v1_expand": - return OpType.V1_EXPAND + if s.lower() == "lora_shrink": + return OpType.LORA_SHRINK + if s.lower() == "lora_expand": + return OpType.LORA_EXPAND raise ValueError(f"Unrecognized str {s} to convert to OpType") def is_shrink_fn(self) -> bool: - return self in [ - OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK - ] + return self in [OpType.LORA_SHRINK] def is_expand_fn(self) -> bool: - return self in [ - OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND - ] - - def is_prefill_op(self) -> bool: - return self in [ - OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK, - OpType.V1_EXPAND - ] - - def is_decode_op(self) -> bool: - return self in [ - OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE, - OpType.V1_SHRINK, OpType.V1_EXPAND - ] - - def is_expand_slice_fn(self) -> bool: - return self in [OpType.BGMV_EXPAND_SLICE] + return self in [OpType.LORA_EXPAND] def num_slices(self) -> list[int]: - if self in [ - OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK, - OpType.V1_EXPAND - ]: - # SGMV kernels and v1 kernels supports slices - return [1, 2, 3] - if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]: - return [1] - if self in [OpType.BGMV_EXPAND_SLICE]: - return [2, 3] - raise ValueError(f"Unrecognized OpType {self}") + return [1, 2, 3] def mkn(self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int) -> tuple[int, int, int]: @@ -239,7 +190,7 @@ def mkn(self, batch_size: int, seq_length: int, hidden_size: int, k = hidden_size n = lora_rank else: - assert self.is_expand_fn() or self.is_expand_slice_fn() + assert self.is_expand_fn() m = num_tokens k = lora_rank n = hidden_size @@ -254,7 +205,7 @@ def matmul_dtypes( if self.is_shrink_fn(): return op_dtype, op_dtype, torch.float32 else: - assert self.is_expand_fn() or 
self.is_expand_slice_fn() + assert self.is_expand_fn() return torch.float32, op_dtype, op_dtype def matmul_shapes( @@ -268,43 +219,19 @@ def matmul_shapes( m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank) b_shape = (num_loras, n, k) # col-major - if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]: - # SGMV shrink and V1 shrink kernels support num_slices inherently - # in the kernel. + if self in [OpType.LORA_SHRINK]: + # LoRA shrink kernels support num_slices inherently in the kernel. return ((m, k), b_shape, (num_slices, m, n)) - if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]: - # SGMV expand and V1 expand kernels support num_slices inherently - # in the kernel + if self in [OpType.LORA_EXPAND]: + # LoRA expand kernels support num_slices inherently in the kernel return ((num_slices, m, k), b_shape, (m, n * num_slices)) - if self == OpType.BGMV_SHRINK: - return ((m, k), b_shape, (m, n)) - if self == OpType.BGMV_EXPAND: - return ((m, k), b_shape, (m, n)) - if self == OpType.BGMV_EXPAND_SLICE: - return ((num_slices, m, k), b_shape, (m, n * num_slices)) - raise ValueError(f"Unrecognized op_type {self}") def bench_fn(self) -> Callable: - - def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]): - for x in kwargs_list: - bgmv_expand_slice(**x) - - if self == OpType.SGMV_SHRINK: - return sgmv_shrink - if self == OpType.SGMV_EXPAND: - return sgmv_expand - if self == OpType.BGMV_SHRINK: - return bgmv_shrink - if self == OpType.BGMV_EXPAND: - return bgmv_expand - if self == OpType.BGMV_EXPAND_SLICE: - return emulate_bgmv_expand_slice - if self == OpType.V1_SHRINK: - return v1_shrink - if self == OpType.V1_EXPAND: - return v1_expand + if self == OpType.LORA_SHRINK: + return lora_shrink + if self == OpType.LORA_EXPAND: + return lora_expand raise ValueError(f"Unrecognized optype {self}") @@ -318,34 +245,13 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor, """ w_dtype = lora_weights[0].dtype num_slices = len(lora_weights) - if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]: + if self in [OpType.LORA_SHRINK]: for slice_idx in range(num_slices): ref_group_gemm(ref_out=output[slice_idx, :], input=input, lora_weights=lora_weights[slice_idx], **kwargs) - elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]: - hidden_size = lora_weights[0].shape[1] - for slice_idx in range(num_slices): - slice_offset = slice_idx * hidden_size - ref_group_gemm( - ref_out=output[:, slice_offset:slice_offset + hidden_size], - input=input[slice_idx].clone().to(dtype=w_dtype), - lora_weights=lora_weights[slice_idx], - **kwargs) - elif self == OpType.BGMV_SHRINK: - assert num_slices == 1 - ref_group_gemm(ref_out=output, - input=input, - lora_weights=lora_weights[0], - **kwargs) - elif self == OpType.BGMV_EXPAND: - assert num_slices == 1 - ref_group_gemm(ref_out=output, - input=input.clone().to(dtype=w_dtype), - lora_weights=lora_weights[0], - **kwargs) - elif self == OpType.BGMV_EXPAND_SLICE: + elif self in [OpType.LORA_EXPAND]: hidden_size = lora_weights[0].shape[1] for slice_idx in range(num_slices): slice_offset = slice_idx * hidden_size @@ -411,13 +317,11 @@ class BenchmarkTensors: input: torch.Tensor lora_weights_lst: list[torch.Tensor] output: torch.Tensor - # metadata tensors + # LoRA kernel metadata + lora_kernel_meta: LoRAKernelMeta + # Metadata tensors used in testing correctness seq_lens: torch.Tensor - seq_start_loc: torch.Tensor prompt_lora_mapping: torch.Tensor - token_lora_mapping: torch.Tensor - # v1 kernel metadata - v1_kernel_meta: Optional[V1KernelMeta] = 
None def io_types(self) -> str: return (f"{dtype_to_str(self.input.dtype)}x" @@ -444,35 +348,29 @@ def make(ctx: BenchmarkContext, assert ctx.num_active_loras <= ctx.num_loras total_tokens = ctx.batch_size * ctx.seq_length + # Make metadata tensors involved in correctness testing. # Prepare seq lens tensor seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size, )) - # Prepare seq_start_loc tensor - seq_start_loc_tensor = torch.cumsum(torch.tensor( - [0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), - dim=0) assert total_tokens == seq_len_tensor.sum() # Prepare prompt lora indices tensor prompt_lora_indices_tensor = make_prompt_lora_mapping( ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu") - # Prepare token lora indices tensor + + # Make LoRAKernelMeta token_lora_indices_tensor = make_token_lora_mapping( total_tokens, ctx.batch_size, prompt_lora_indices_tensor, seq_len_tensor, "cpu") - - v1_kernel_meta = None - if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]: - v1_kernel_meta = V1KernelMeta.make( - max_loras=ctx.num_loras, - max_num_tokens=token_lora_indices_tensor.size(0), - device="cpu") - v1_kernel_meta.prepare_tensors( - token_lora_mapping=token_lora_indices_tensor) + lora_kernel_meta = LoRAKernelMeta.make( + max_loras=ctx.num_loras, + max_num_tokens=token_lora_indices_tensor.size(0), + device="cpu") + lora_kernel_meta.prepare_tensors( + token_lora_mapping=token_lora_indices_tensor) return BenchmarkTensors(input_tensor, lora_weights, output_tensor, - seq_len_tensor, seq_start_loc_tensor, - prompt_lora_indices_tensor, - token_lora_indices_tensor, v1_kernel_meta) + lora_kernel_meta, seq_len_tensor, + prompt_lora_indices_tensor) def sanity_check(self) -> None: """ @@ -482,9 +380,9 @@ def sanity_check(self) -> None: # check metadata tensors assert torch.sum(self.seq_lens) == num_tokens num_seqs = self.seq_lens.shape[0] - assert self.seq_start_loc.shape[0] == num_seqs assert self.prompt_lora_mapping.shape[0] == num_seqs - assert self.token_lora_mapping.shape[0] == num_tokens + assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens def to_device(self, device: str): """ @@ -499,220 +397,27 @@ def to_device(tensor: torch.Tensor): self.input = to_device(self.input) self.output = to_device(self.output) self.seq_lens = to_device(self.seq_lens) - self.seq_start_loc = to_device(self.seq_start_loc) self.prompt_lora_mapping = to_device(self.prompt_lora_mapping) - self.token_lora_mapping = to_device(self.token_lora_mapping) for i in range(len(self.lora_weights_lst)): self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i]) - # v1 meta - if self.v1_kernel_meta: - for field_name in V1KernelMeta.__dataclass_fields__: - field = getattr(self.v1_kernel_meta, field_name) - assert isinstance(field, torch.Tensor) - setattr(self.v1_kernel_meta, field_name, to_device(field)) + # LoRA meta + for field_name in LoRAKernelMeta.__dataclass_fields__: + field = getattr(self.lora_kernel_meta, field_name) + assert isinstance(field, torch.Tensor) + setattr(self.lora_kernel_meta, field_name, to_device(field)) - def metadata(self) -> tuple[int, int, int]: + def metadata(self) -> tuple[int, int, int, int]: """ - Return num_seqs, num_tokens and max_seq_len + Return num_seqs, num_tokens, max_seq_len and num_slices """ num_seqs = self.seq_lens.shape[0] - num_tokens = self.token_lora_mapping.shape[0] + num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0] max_seq_len = torch.max(self.seq_lens).item() num_slices = len(self.lora_weights_lst) return num_seqs, num_tokens, max_seq_len, num_slices - def 
convert_to_sgmv_benchmark_tensors(self): - """ - For sgmv punica kernels, when consecutive sequences have the - same LoRA ID, we just merge them together. - This happens in punica.py::compute_metadata - """ - - # Collapse seq_lens and seq_start_loc - _, seq_lens = torch.unique_consecutive(self.token_lora_mapping, - return_counts=True) - cum_result = torch.cumsum(seq_lens, dim=0) - seq_start_loc = torch.zeros_like(seq_lens) - seq_start_loc[1:].copy_(cum_result[:-1]) - - # Collapse prompt mapping - prompt_lora_mapping = torch.unique_consecutive( - self.prompt_lora_mapping) - - assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \ - f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}" - - self.prompt_lora_mapping = prompt_lora_mapping.to( - dtype=self.prompt_lora_mapping.dtype) - self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype) - self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype) - - def as_sgmv_shrink_kwargs(self) -> dict[str, Any]: - self.convert_to_sgmv_benchmark_tensors() - self.sanity_check() - self.to_device(self.input.device) - - num_seqs, num_tokens, max_seq_len, num_slices = self.metadata() - - # Sanity check matrix shapes. - i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape - # Expected input shape [num_tokens, hidden_size] - assert len(i_shape) == 2 - assert i_shape[0] == num_tokens - hidden_size = i_shape[1] - # Expected lora weight shape [num_loras, lora_rank, hidden_size] - assert len(lw_shape) == 3 - assert lw_shape[2] == hidden_size - lora_rank = lw_shape[1] - # Expected output shape [num_slices, num_tokens, lora_rank] - assert len(o_shape) == 3 - assert o_shape == (num_slices, num_tokens, lora_rank) - - return { - 'inputs': self.input, - 'lora_a_weights': self.lora_weights_lst, - 'output_tensor': self.output, - 'b_seq_start_loc': self.seq_start_loc, - 'seq_len_tensor': self.seq_lens, - 'lora_indices_tensor': self.prompt_lora_mapping, - 'batches': num_seqs, - 'max_seq_length': max_seq_len, - 'token_nums': num_tokens, - 'scaling': 1.0, - } - - def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: - - self.convert_to_sgmv_benchmark_tensors() - self.sanity_check() - self.to_device(self.input.device) - - num_seqs, num_tokens, max_seq_len, num_slices = self.metadata() - - # Sanity check matrix shapes. 
- i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape - # Expected input shape : [num_slices, num_tokens, lora_rank] - assert len(i_shape) == 3 - assert i_shape[0] == num_slices - assert i_shape[1] == num_tokens - lora_rank = i_shape[2] - # Expected lora weight shape : [num_lora, hidden_size, lora_rank] - assert len(lw_shape) == 3 - assert lw_shape[2] == lora_rank - hidden_size = lw_shape[1] - # Expected output shape : [num_tokens, hidden_size * num_slices] - assert len(o_shape) == 2 - assert o_shape == (num_tokens, hidden_size * num_slices) - - return { - 'inputs': self.input, - 'lora_b_weights': self.lora_weights_lst, - 'output_tensor': self.output, - 'b_seq_start_loc': self.seq_start_loc, - 'seq_len_tensor': self.seq_lens, - 'lora_indices_tensor': self.prompt_lora_mapping, - 'batches': num_seqs, - 'max_seq_length': max_seq_len, - 'token_nums': num_tokens, - 'offset_start': 0, - 'add_inputs': add_inputs, - } - - def as_bgmv_shrink_kwargs(self) -> dict[str, Any]: - assert len(self.lora_weights_lst) == 1 - self.to_device(self.input.device) - - _, num_tokens, _, _ = self.metadata() - # Sanity check shapes - i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape - # Expected input shape [num_tokens, hidden_size] - assert len(i_shape) == 2 - assert i_shape[0] == num_tokens - hidden_size = i_shape[1] - # Expected lora weight shape [num_loras, lora_rank, hidden_size] - assert len(lw_shape) == 3 - assert lw_shape[2] == hidden_size - lora_rank = lw_shape[1] - # Expected output shape [num_tokens, lora_rank] - assert len(o_shape) == 2 - assert o_shape == (num_tokens, lora_rank) - - return { - 'inputs': self.input, - 'lora_a_weights': self.lora_weights_lst[0], - 'output_tensor': self.output, - 'lora_indices_tensor': self.token_lora_mapping, - 'scaling': 1.0 - } - - def as_bgmv_expand_kwargs(self, add_inputs: bool): - assert len(self.lora_weights_lst) == 1 - self.to_device(self.input.device) - - _, num_tokens, _, _ = self.metadata() - # Sanity check shapes - i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape - # Expected input shape [num_tokens, lora_rank] - assert len(i_shape) == 2 - assert i_shape[0] == num_tokens - lora_rank = i_shape[1] - # Expected lora weight shape [num_loras, hidden_size, lora_rank] - assert len(lw_shape) == 3 - assert lw_shape[2] == lora_rank - hidden_size = lw_shape[1] - # Expected output shape [num_tokens, hidden_size] - assert len(o_shape) == 2 - assert o_shape == (num_tokens, hidden_size) - - return { - 'inputs': self.input, - 'lora_b_weights': self.lora_weights_lst[0], - 'output_tensor': self.output, - 'lora_indices_tensor': self.token_lora_mapping, - 'add_inputs': add_inputs - } - - def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]: - - _, num_tokens, _, num_slices = self.metadata() - # Sanity check shapes - i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape - # Expected input shape [num_slices, num_tokens, lora_rank] - assert len(i_shape) == 3 - assert i_shape[0] == num_slices - assert i_shape[1] == num_tokens - lora_rank = i_shape[2] - # Expected lora weight shape [num_loras, hidden_size, lora_rank] - assert len(lw_shape) == 3 - assert lw_shape[2] == lora_rank - hidden_size = lw_shape[1] - # Expected output shape [num_tokens, hidden_size * num_slices] - assert len(o_shape) == 2 - assert o_shape == (num_tokens, hidden_size * num_slices) - - 
self.to_device(self.input.device) - - kwargs_list = [] - for i in range(num_slices): - kwargs_list.append({ - 'inputs': self.input[i], - 'lora_b_weights': self.lora_weights_lst[i], - 'output_tensor': self.output, - 'lora_indices_tensor': self.token_lora_mapping, - 'slice_offset': i * hidden_size, - 'slice_size': hidden_size, - 'add_inputs': add_inputs, - }) - return {'kwargs_list': kwargs_list} - - def as_v1_shrink_kwargs(self) -> dict[str, Any]: - assert self.v1_kernel_meta is not None + def as_lora_shrink_kwargs(self) -> dict[str, Any]: self.sanity_check() self.to_device(self.input.device) @@ -737,17 +442,16 @@ def as_v1_shrink_kwargs(self) -> dict[str, Any]: 'inputs': self.input, 'lora_a_weights': self.lora_weights_lst, 'output_tensor': self.output, - 'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping, + 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping, 'token_indices_sorted_by_lora_ids': - self.v1_kernel_meta.token_indices_sorted_by_lora_ids, - 'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora, - 'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc, - 'lora_ids': self.v1_kernel_meta.active_lora_ids, + self.lora_kernel_meta.token_indices_sorted_by_lora_ids, + 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora, + 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc, + 'lora_ids': self.lora_kernel_meta.active_lora_ids, 'scaling': 1.0, } - def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: - assert self.v1_kernel_meta is not None + def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: self.sanity_check() self.to_device(self.input.device) @@ -773,12 +477,12 @@ def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: 'inputs': self.input, 'lora_b_weights': self.lora_weights_lst, 'output_tensor': self.output, - 'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping, + 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping, 'token_indices_sorted_by_lora_ids': - self.v1_kernel_meta.token_indices_sorted_by_lora_ids, - 'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora, - 'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc, - 'lora_ids': self.v1_kernel_meta.active_lora_ids, + self.lora_kernel_meta.token_indices_sorted_by_lora_ids, + 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora, + 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc, + 'lora_ids': self.lora_kernel_meta.active_lora_ids, 'offset_start': 0, 'add_inputs': add_inputs, } @@ -791,20 +495,10 @@ def bench_fn_kwargs(self, else: assert add_inputs is not None - if op_type == OpType.SGMV_SHRINK: - return self.as_sgmv_shrink_kwargs() - if op_type == OpType.SGMV_EXPAND: - return self.as_sgmv_expand_kwargs(add_inputs) - if op_type == OpType.BGMV_SHRINK: - return self.as_bgmv_shrink_kwargs() - if op_type == OpType.BGMV_EXPAND: - return self.as_bgmv_expand_kwargs(add_inputs) - if op_type == OpType.BGMV_EXPAND_SLICE: - return self.as_bgmv_expand_slice_kwargs(add_inputs) - if op_type == OpType.V1_SHRINK: - return self.as_v1_shrink_kwargs() - if op_type == OpType.V1_EXPAND: - return self.as_v1_expand_kwargs(add_inputs) + if op_type == OpType.LORA_SHRINK: + return self.as_lora_shrink_kwargs() + if op_type == OpType.LORA_EXPAND: + return self.as_lora_expand_kwargs(add_inputs) raise ValueError(f"Unrecognized optype {self}") def test_correctness(self, op_type: OpType, @@ -993,10 +687,6 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): for 
bench_ctx in bench_ctxs: for seq_len in args.seq_lengths: bench_ops: list[OpType] = args.op_types - if seq_len > 1: - # bench only prefill ops - bench_ops = [op for op in args.op_types if op.is_prefill_op()] - seq_len_timers = [] for bench_op in bench_ops: for num_slices in bench_op.num_slices(): @@ -1206,13 +896,13 @@ def add_common_command_args(p: argparse.ArgumentParser): {use_cuda_graph_recommendation()} list_bench example: - python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 + python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 model_bench example: - python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 + python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 range_bench example: - python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 + python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 """, # noqa: E501 formatter_class=argparse.RawTextHelpFormatter) diff --git a/tests/lora/test_punica_ops.py b/tests/lora/test_punica_ops.py index a412a80dd70a..726d0c5f2f0d 100644 --- a/tests/lora/test_punica_ops.py +++ b/tests/lora/test_punica_ops.py @@ -4,18 +4,13 @@ import pytest import torch -import vllm.lora.ops.triton_ops # noqa: F401 -import vllm.lora.ops.triton_ops.v1 # noqa: F401 -from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) +import vllm.lora.ops.torch_ops as torch_ops +import vllm.lora.ops.triton_ops as triton_ops +from vllm.lora.ops.triton_ops import LoRAKernelMeta from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT -from vllm.lora.ops.triton_ops.v1 import V1KernelMeta from vllm.platforms import current_platform -from .utils import (PunicaTensors, assert_close, generate_data, - generate_data_for_expand_nslices, - generate_data_for_nslices) +from .utils import PunicaTensors, assert_close, generate_data_for_nslices # Utility shrink and expand operations used as 
reference implementations. @@ -26,10 +21,10 @@ def sgmv_shrink_for_nslices( prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int, num_tokens: int, scaling: float): """ - Wrapper around sgmv_shrink that handles any nslices. + Wrapper around torch_ops.sgmv_shrink that handles any nslices. """ for index in range(nslices): - sgmv_shrink( + torch_ops.sgmv_shrink( inputs_tensor, lora_weights_lst[index], out_tensor[index], @@ -53,11 +48,11 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int, max_seq_length: int, num_tokens: int, add_inputs: bool) -> None: """ - Wrapper around sgmv_expand that handles any nslices. + Wrapper around torch_ops.sgmv_expand that handles any nslices. """ if nslices == 1: # Verify the torch's sgmv_expand op - sgmv_expand( + torch_ops.sgmv_expand( inputs_tensor[0], lora_weights_lst[0], out_tensor, @@ -73,7 +68,7 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int, slice_offset = 0 for index in range(nslices): lora_weights = lora_weights_lst[index] - sgmv_expand_slice( + torch_ops.sgmv_expand_slice( inputs_tensor[index], lora_weights, out_tensor, @@ -93,12 +88,13 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int, _dict_lock = Lock() -def check_shrink_kernels(batches: int, num_loras: int, rank: int, - hidden_size: int, nslices: int, dtype: torch.dtype, - device: str, seq_length: int, scaling: float): +def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, + dtype: torch.dtype, device: str, seq_length: int, + scaling: float): """ - Compare outputs of vllm.sgmv_shrink and vllm.v1_shrink kernel against a - reference implementation. + Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink + kernels. """ data: PunicaTensors = generate_data_for_nslices( batches, @@ -118,35 +114,24 @@ def check_shrink_kernels(batches: int, num_loras: int, rank: int, data.prompt_lora_mapping, batches, max_seq_length, token_nums) - # Setup metadata information for the V1 kernel. - v1_meta = V1KernelMeta.make(max_loras=num_loras, - max_num_tokens=token_nums, - device='cuda') - v1_meta.prepare_tensors(data.token_lora_mapping) + # Setup metadata information for the LoRA kernel. + lora_meta = LoRAKernelMeta.make(max_loras=num_loras, + max_num_tokens=token_nums, + device='cuda') + lora_meta.prepare_tensors(data.token_lora_mapping) ref_out_tensor = data.ref_out_tensor - sgmv_out_tensor = data.our_out_tensor - v1_out_tensor = data.our_out_tensor.clone() + out_tensor = data.our_out_tensor.clone() # Preventing cache error pointer. 
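    # _LORA_A_PTR_DICT and _LORA_B_PTR_DICT (vllm.lora.ops.triton_ops.utils)
    # cache pointer/metadata tensors derived from each LoRA weight list.
    # Clearing them under _dict_lock before the launch keeps a parametrized
    # run from reusing a cached entry built from a previous run's weights.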
with _dict_lock: - # SGMV shrink kernel + # lora_shrink kernel _LORA_A_PTR_DICT.clear() - torch.ops.vllm.sgmv_shrink( + triton_ops.lora_shrink( data.inputs_tensor, data.lora_weights, - sgmv_out_tensor, - *sgmv_meta_args, - scaling, - ) - - # V1 shrink kernel - _LORA_A_PTR_DICT.clear() - torch.ops.vllm.v1_shrink( - data.inputs_tensor, - data.lora_weights, - v1_out_tensor, - *v1_meta.meta_args(token_nums=token_nums), + out_tensor, + *lora_meta.meta_args(token_nums=token_nums), scaling, ) @@ -160,16 +145,16 @@ def check_shrink_kernels(batches: int, num_loras: int, rank: int, scaling, ) - assert_close(sgmv_out_tensor, ref_out_tensor) - assert_close(v1_out_tensor, ref_out_tensor) + assert_close(out_tensor, ref_out_tensor) -def check_expand_kernels(batches: int, num_loras: int, rank: int, - hidden_size: int, nslices: int, dtype: torch.dtype, - device: str, seq_length: int, add_inputs: bool): +def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, + dtype: torch.dtype, device: str, seq_length: int, + add_inputs: bool): """ - Compare outputs of vllm.sgmv_expand and vllm.v1_expand kernels against a - reference implementation. + Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand + kernels. """ data: PunicaTensors = generate_data_for_nslices( batches, @@ -190,37 +175,25 @@ def check_expand_kernels(batches: int, num_loras: int, rank: int, data.prompt_lora_mapping, batches, max_seq_length, token_nums) - # Setup metadata information for the V1 kernel. - v1_meta = V1KernelMeta.make(max_loras=num_loras, - max_num_tokens=token_nums, - device='cuda') - v1_meta.prepare_tensors(data.token_lora_mapping) + # Setup metadata information for the LoRA kernel. + lora_meta = LoRAKernelMeta.make(max_loras=num_loras, + max_num_tokens=token_nums, + device='cuda') + lora_meta.prepare_tensors(data.token_lora_mapping) # Setup output tensors ref_out_tensor = data.ref_out_tensor - sgmv_out_tensor = data.our_out_tensor - v1_out_tensor = data.our_out_tensor.clone() + out_tensor = data.our_out_tensor.clone() with _dict_lock: - # SGMV expand kernel - _LORA_B_PTR_DICT.clear() - torch.ops.vllm.sgmv_expand( - data.inputs_tensor, - data.lora_weights, - sgmv_out_tensor, - *sgmv_meta_args, - offset_start=0, - add_inputs=add_inputs, - ) - - # V1 expand kernel + # lora_expand kernel _LORA_B_PTR_DICT.clear() - torch.ops.vllm.v1_expand(data.inputs_tensor, - data.lora_weights, - v1_out_tensor, - *v1_meta.meta_args(token_nums=token_nums), - offset_start=0, - add_inputs=add_inputs) + triton_ops.lora_expand(data.inputs_tensor, + data.lora_weights, + out_tensor, + *lora_meta.meta_args(token_nums=token_nums), + offset_start=0, + add_inputs=add_inputs) # Reference sgmv_expand_for_nslices(nslices, @@ -231,124 +204,7 @@ def check_expand_kernels(batches: int, num_loras: int, rank: int, *sgmv_meta_args, add_inputs=add_inputs) - assert_close(sgmv_out_tensor, ref_out_tensor) - assert_close(v1_out_tensor, ref_out_tensor) - - -def check_bgmv_shrink(batches: int, num_loras: int, rank: int, - hidden_size: int, dtype: torch.dtype, device: str, - scaling: float): - """ - Compare vllm.bgmv_shrink against a reference implementation. 
- """ - seq_length = 1 - data: PunicaTensors = generate_data( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - "shrink", - device, - ) - - torch.ops.vllm.bgmv_shrink( - data.inputs_tensor, - data.lora_weights, - data.our_out_tensor, - data.token_lora_mapping, - scaling, - ) - - bgmv_shrink( - data.inputs_tensor, - data.lora_weights, - data.ref_out_tensor, - data.token_lora_mapping, - scaling, - ) - - data.ref_out_tensor = data.ref_out_tensor.to(torch.float32) - assert_close(data.our_out_tensor, data.ref_out_tensor) - - -def check_bgmv_expand(batches: int, num_loras: int, rank: int, - hidden_size: int, dtype: torch.dtype, device: str, - add_inputs: bool): - """ - Compare vllm.bgmv_expand against a reference implementation. - """ - seq_length = 1 - data: PunicaTensors = generate_data( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - "expand", - device, - ) - - torch.ops.vllm.bgmv_expand( - data.inputs_tensor, - data.lora_weights, - data.our_out_tensor, - data.token_lora_mapping, - add_inputs=add_inputs, - ) - bgmv_expand( - data.inputs_tensor, - data.lora_weights, - data.ref_out_tensor, - data.token_lora_mapping, - add_inputs=add_inputs, - ) - assert_close(data.our_out_tensor, data.ref_out_tensor) - - -def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int, - hidden_size: int, nslices: int, dtype: torch.dtype, - device: str, add_inputs: bool): - """ - Compare vllm.bgmv_expand_slice against a reference implementation. - """ - seq_length = 1 - data: PunicaTensors = generate_data_for_expand_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - nslices, - device, - ) - - slice_offset = 0 - for index in range(nslices): - torch.ops.vllm.bgmv_expand_slice( - data.inputs_tensor, - data.lora_weights[index], - data.our_out_tensor, - data.token_lora_mapping, - slice_offset, - slice_size=hidden_size, - add_inputs=add_inputs, - ) - bgmv_expand_slice( - data.inputs_tensor, - data.lora_weights[index], - data.ref_out_tensor, - data.token_lora_mapping, - slice_offset, - slice_size=hidden_size, - add_inputs=add_inputs, - ) - - slice_offset += hidden_size - assert_close(data.our_out_tensor, data.ref_out_tensor) + assert_close(out_tensor, ref_out_tensor) # Tests @@ -490,31 +346,31 @@ def test_kernels( op_type: str, ): """ - Tests SGMV and V1 kernels. + Tests LoRA kernels. """ torch.set_default_device(device) current_platform.seed_everything(seed) if op_type == "shrink": - check_shrink_kernels(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - scaling=0.5) + check_lora_shrink_kernel(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5) else: - check_expand_kernels(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - add_inputs=True) + check_lora_expand_kernel(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True) @pytest.mark.parametrize("batches", hs_test_params['batches']) @@ -538,159 +394,28 @@ def test_kernels_hidden_size( op_type: str, ): """ - Tests SGMV and V1 kernels. + Tests SGMV and LoRA kernels. 
""" torch.set_default_device(device) current_platform.seed_everything(seed) if op_type == "shrink": - check_shrink_kernels(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - scaling=0.5) - else: - check_expand_kernels(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - add_inputs=True) - - -@pytest.mark.parametrize("batches", test_params['batches']) -@pytest.mark.parametrize("num_loras", test_params['num_loras']) -@pytest.mark.parametrize("rank", test_params['max_ranks']) -@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -def test_punica_bgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - dtype: torch.dtype, - device: str, - seed: int, - op_type: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - if op_type == "shrink": - check_bgmv_shrink(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - dtype=dtype, - device=device, - scaling=0.5) - else: - check_bgmv_expand(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - dtype=dtype, - device=device, - add_inputs=True) - - -@pytest.mark.parametrize("batches", hs_test_params['batches']) -@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) -@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) -@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes']) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -def test_punica_bgmv_hidden_size( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - dtype: torch.dtype, - device: str, - seed: int, - op_type: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - if op_type == "shrink": - check_bgmv_shrink(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - dtype=dtype, - device=device, - scaling=0.5) + check_lora_shrink_kernel(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5) else: - check_bgmv_expand(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - dtype=dtype, - device=device, - add_inputs=True) - - -@pytest.mark.parametrize("batches", test_params['batches']) -@pytest.mark.parametrize("num_loras", test_params['num_loras']) -@pytest.mark.parametrize("rank", test_params['max_ranks']) -@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) -@pytest.mark.parametrize("nslices", [2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("seed", SEED) -def test_punica_bgmv_expand_nslices(batches: int, num_loras: int, rank: int, - hidden_size: int, nslices: int, - dtype: torch.dtype, device: str, - seed: int): - - torch.set_default_device(device) - current_platform.seed_everything(seed) - - check_bgmv_expand_slice(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - 
device=device, - add_inputs=True) - - -@pytest.mark.parametrize("batches", hs_test_params['batches']) -@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) -@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) -@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes']) -@pytest.mark.parametrize("nslices", [2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("seed", SEED) -def test_punica_bgmv_expand_nslices_hidden_size(batches: int, num_loras: int, - rank: int, hidden_size: int, - nslices: int, - dtype: torch.dtype, - device: str, seed: int): - - torch.set_default_device(device) - current_platform.seed_everything(seed) - - check_bgmv_expand_slice(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - add_inputs=True) + check_lora_expand_kernel(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True) diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py index dc440f7327fa..acae0d972f4e 100644 --- a/vllm/lora/ops/triton_ops/__init__.py +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -1,15 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401 +from vllm.lora.ops.triton_ops.lora_expand import lora_expand +from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta +from vllm.lora.ops.triton_ops.lora_shrink import lora_shrink __all__ = [ - "bgmv_expand", - "bgmv_expand_slice", - "bgmv_shrink", - "sgmv_expand", - "sgmv_shrink", + "lora_expand", + "lora_shrink", + "LoRAKernelMeta", ] diff --git a/vllm/lora/ops/triton_ops/bgmv_expand.py b/vllm/lora/ops/triton_ops/bgmv_expand.py deleted file mode 100644 index 98510b39661a..000000000000 --- a/vllm/lora/ops/triton_ops/bgmv_expand.py +++ /dev/null @@ -1,188 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. 
-https://arxiv.org/abs/2310.18547 -""" - -import torch -import triton -import triton.language as tl - -from vllm.utils import direct_register_custom_op - -from .utils import get_lora_op_configs - - -@triton.jit -def _bgmv_expand_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - lora_indices, - xm_stride, - xk_stride, - l0_stride, - lora_k_stride, - lora_n_stride, - cm_stride, - cn_stride, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - SPLIT_N: tl.constexpr, - EVEN_K: tl.constexpr, - ADD_INPUTS: tl.constexpr, - CAST_TYPE: tl.constexpr, -): - """ - GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's - performance - """ - pid_sn = tl.program_id(axis=0) - cur_batch = tl.program_id(axis=1) - lora_index = tl.load(lora_indices + cur_batch) - if lora_index == -1: - return - offset_k = tl.arange(0, BLOCK_K) - offset_n = tl.arange(0, BLOCK_N) - if EVEN_K: - tiled_a = tl.load(input_ptr + cur_batch * xm_stride + - offset_k * xk_stride, ) # [BLOCK_K] - else: - tiled_a = tl.load( - input_ptr + cur_batch * xm_stride + offset_k * xk_stride, - mask=offset_k < K, - other=0, - ) # [BLOCK_K] - # N must be divisible by SPLIT_N - split_n_length = tl.cdiv(N, SPLIT_N) - if CAST_TYPE: - tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) - # sliding to next row-block - b_ptr = (lora_ptr + l0_stride * lora_index + - pid_sn * split_n_length * lora_k_stride) - c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length - for n in range(0, split_n_length, BLOCK_N): - current_n = n + offset_n - current_n_c = tl.max_contiguous(current_n, BLOCK_N) - b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] - < K) - c_mask = current_n < split_n_length - tiled_b = tl.load( - b_ptr + current_n_c[:, None] * lora_k_stride + - offset_k[None, :] * lora_n_stride, - mask=b_ptr_mask, - other=0.0, - ) # [BLOCK_N,BLOCK_K] - if ADD_INPUTS: - tiled_out = tl.load(c_ptr + current_n * cn_stride, - mask=c_mask, - other=0.0) - accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out - else: - accumulator = tl.sum(tiled_a * tiled_b, 1) - - tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) - - -@torch.inference_mode() -def _bgmv_expand( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True, -) -> None: - """ - Args: - inputs (torch.Tensor): input tensor - lora_b_weights (torch.Tensor): lora'a weight - output_tensor (torch.Tensor): output tensor - lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index - corresponding to each batch, An index of -1 means no lora should be - applied. - batches (int): batch size - add_inputs (bool, optional): Defaults to False, adds the final lora - results to the output. 
- """ - assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] - assert lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, - ] - assert inputs.size(1) == lora_b_weights.size(-1) - - assert inputs.is_contiguous() - assert output_tensor.is_contiguous() - - if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_b_weights.size(1) == 1 - lora_b_weights = lora_b_weights.squeeze(dim=1) - else: - assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) - assert lora_b_weights.is_contiguous() - - # TODO tuning this config - N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size - BLOCK_K = triton.next_power_of_2(K) - EVEN_K = K % BLOCK_K == 0 - ADD_INPUTS = add_inputs - CAST_TYPE = False - if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, - ]: - CAST_TYPE = True - batches = lora_indices_tensor.size(0) - config = get_lora_op_configs("expand", batches, N) - grid = lambda META: ( - META["SPLIT_N"], - batches, - ) - _bgmv_expand_kernel[grid]( - inputs, - lora_b_weights, - output_tensor, - N, - K, - lora_indices_tensor, - inputs.stride(0), - inputs.stride(1), - lora_b_weights.stride(0), - lora_b_weights.stride(1), - lora_b_weights.stride(2), - output_tensor.stride(0), - output_tensor.stride(1), - BLOCK_K=BLOCK_K, - EVEN_K=EVEN_K, - ADD_INPUTS=ADD_INPUTS, - CAST_TYPE=CAST_TYPE, - **config, - ) - return - - -def bgmv_expand_fake( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True, -) -> None: - return - - -try: - direct_register_custom_op( - op_name="bgmv_expand", - op_func=_bgmv_expand, - mutates_args=["output_tensor"], - fake_impl=bgmv_expand_fake, - ) - bgmv_expand = torch.ops.vllm.bgmv_expand - -except AttributeError: - bgmv_expand = _bgmv_expand diff --git a/vllm/lora/ops/triton_ops/bgmv_expand_slice.py b/vllm/lora/ops/triton_ops/bgmv_expand_slice.py deleted file mode 100644 index 48804123c1ea..000000000000 --- a/vllm/lora/ops/triton_ops/bgmv_expand_slice.py +++ /dev/null @@ -1,207 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. 
-https://arxiv.org/abs/2310.18547 -""" - -import torch -import triton -import triton.language as tl - -from vllm.utils import direct_register_custom_op - -from .utils import get_lora_op_configs - - -@triton.jit -def _bgmv_expand_slice_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - lora_indices, - xm_stride, - xk_stride, - l0_stride, - lora_k_stride, - lora_n_stride, - cm_stride, - cn_stride, - slice_offset, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - SPLIT_N: tl.constexpr, - EVEN_K: tl.constexpr, - ADD_INPUTS: tl.constexpr, - CAST_TYPE: tl.constexpr, -): - """ - GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's - performance - """ - pid_sn = tl.program_id(axis=0) - cur_batch = tl.program_id(axis=1) - lora_index = tl.load(lora_indices + cur_batch) - if lora_index == -1: - return - offset_k = tl.arange(0, BLOCK_K) - offset_n = tl.arange(0, BLOCK_N) - if EVEN_K: - tiled_a = tl.load(input_ptr + cur_batch * xm_stride + - offset_k * xk_stride, ) # [BLOCK_K] - else: - tiled_a = tl.load( - input_ptr + cur_batch * xm_stride + offset_k * xk_stride, - mask=offset_k < K, - other=0, - ) # [BLOCK_K] - # N must be divisible by SPLIT_N - split_n_length = tl.cdiv(N, SPLIT_N) - if CAST_TYPE: - tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) - # sliding to next row-block - b_ptr = (lora_ptr + l0_stride * lora_index + - pid_sn * split_n_length * lora_k_stride) - c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + - slice_offset * cn_stride) - - for n in range(0, split_n_length, BLOCK_N): - current_n = n + offset_n - b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] - < K) - c_mask = current_n < split_n_length - tiled_b = tl.load( - b_ptr + current_n[:, None] * lora_k_stride + - offset_k[None, :] * lora_n_stride, - mask=b_ptr_mask, - other=0.0, - ) # [BLOCK_N,BLOCK_K] - - if ADD_INPUTS: - # explicitly pass in other=None to tell triton that masked values - # can be uninitialized. This is OK because the later tl.store - # operation uses the same mask, eliminating the risk of garbage - # values propagating - tiled_out = tl.load(c_ptr + current_n * cn_stride, - mask=c_mask, - other=None) - accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out - else: - accumulator = tl.sum(tiled_a * tiled_b, 1) - - tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) - - -@torch.inference_mode() -def _bgmv_expand_slice( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True, -) -> None: - """ - Args: - inputs (torch.Tensor): input tensor - lora_b_weights (torch.Tensor): lora'b weight - output_tensor (torch.Tensor): output tensor - lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index - corresponding to each batch, An index of -1 means no lora should be - applied. - slice_offset (int): output_tensor's offset - slice_size (int): current output_tensor's size - batches (int): batch size - add_inputs (bool, optional): Defaults to False. 
- """ - assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] - assert lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, - ] - assert inputs.size(1) == lora_b_weights.size(-1) - - assert slice_size == lora_b_weights.size(-2) - assert inputs.is_contiguous() - assert output_tensor.is_contiguous() - - if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_b_weights.size(1) == 1 - lora_b_weights = lora_b_weights.squeeze(dim=1) - else: - assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) - - assert lora_b_weights.is_contiguous() - - # TODO tuning this config - - N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size - BLOCK_K = triton.next_power_of_2(K) - EVEN_K = K % BLOCK_K == 0 - ADD_INPUTS = add_inputs - CAST_TYPE = False - if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, - ]: - CAST_TYPE = True - - batches = lora_indices_tensor.size(0) - - config = get_lora_op_configs("expand", batches, N) - - grid = lambda META: ( - META["SPLIT_N"], - batches, - ) - _bgmv_expand_slice_kernel[grid]( - inputs, - lora_b_weights, - output_tensor, - N, - K, - lora_indices_tensor, - inputs.stride(0), - inputs.stride(1), - lora_b_weights.stride(0), - lora_b_weights.stride(1), - lora_b_weights.stride(2), - output_tensor.stride(0), - output_tensor.stride(1), - slice_offset, - BLOCK_K=BLOCK_K, - EVEN_K=EVEN_K, - ADD_INPUTS=ADD_INPUTS, - CAST_TYPE=CAST_TYPE, - **config, - ) - return - - -def bgmv_expand_slice_fake( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True, -) -> None: - return - - -try: - direct_register_custom_op( - op_name="bgmv_expand_slice", - op_func=_bgmv_expand_slice, - mutates_args=["output_tensor"], - fake_impl=bgmv_expand_slice_fake, - ) - bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice - -except AttributeError: - bgmv_expand_slice = _bgmv_expand_slice diff --git a/vllm/lora/ops/triton_ops/bgmv_shrink.py b/vllm/lora/ops/triton_ops/bgmv_shrink.py deleted file mode 100644 index 227a5765e56b..000000000000 --- a/vllm/lora/ops/triton_ops/bgmv_shrink.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. 
-https://arxiv.org/abs/2310.18547 -""" - -import torch -import triton -import triton.language as tl - -from vllm.utils import direct_register_custom_op - -from .utils import get_lora_op_configs - - -@triton.jit -def _bgmv_shrink_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - lora_indices, - scaling, - xm_stride, - xk_stride, - l0_stride, - lora_k_stride, - lora_n_stride, - cm_stride, - cn_stride, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - SPLIT_K: tl.constexpr, -): - """ - GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's - performance - """ - pid_sk = tl.program_id(axis=0) - cur_batch = tl.program_id(axis=1) - lora_index = tl.load(lora_indices + cur_batch) - if lora_index == -1: - return - - offset_n = tl.arange(0, BLOCK_N) - offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K - a_ptr = input_ptr + cur_batch * xm_stride - b_ptr = lora_ptr + l0_stride * lora_index - accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32) - for k in range(0, K, BLOCK_K * SPLIT_K): - current_k = k + offset_k - current_k_c = tl.max_contiguous(current_k, BLOCK_K) - tiled_a = tl.load( - a_ptr + current_k_c, - mask=current_k < K, - other=0.0, - ) # [BLOCK_K] - b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K) - - tiled_b = tl.load( - b_ptr + offset_n[:, None] * lora_k_stride + - current_k[None, :] * lora_n_stride, - mask=b_ptr_mask, - other=0.0, - ) # [BLOCK_N,BLOCK_K] - - accumulator += tl.sum(tiled_a * tiled_b, 1) - accumulator *= scaling - offset_cn = tl.arange(0, BLOCK_N) - c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride - c_mask = offset_cn < N - if SPLIT_K == 1: - tl.store(c_ptr, accumulator, mask=c_mask) - else: - tl.atomic_add(c_ptr, accumulator, mask=c_mask) - - -@torch.inference_mode() -def _bgmv_shrink( - inputs: torch.Tensor, - lora_a_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0, -) -> None: - """ - Args: - inputs (torch.Tensor): input tensor - lora_a_weights (torch.Tensor): lora'a weight - output_tensor (torch.Tensor): output tensor - lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index - corresponding to each batch. An index of -1 means no lora should be - applied. - batches (int): batch size - scaling (float): Scaling factor. 
- """ - assert inputs.dtype == lora_a_weights.dtype - assert inputs.dtype in [torch.float16, torch.bfloat16] - assert lora_a_weights.dtype in [ - torch.float16, - torch.bfloat16, - ] - assert inputs.size(1) == lora_a_weights.size(-1) - assert inputs.is_contiguous() - - if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) - assert lora_a_weights.size(1) == 1 - lora_a_weights = lora_a_weights.squeeze(dim=1) - else: - assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) - assert lora_a_weights.is_contiguous() - assert output_tensor.is_contiguous() - # TODO tuning this config - batches = lora_indices_tensor.size(0) - N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank - BLOCK_N = triton.next_power_of_2(N) - # First try to load optimal config from the file - config = get_lora_op_configs("bgmv_shrink", batches, K) - - grid = lambda META: ( - META["SPLIT_K"], - batches, - ) - _bgmv_shrink_kernel[grid]( - inputs, - lora_a_weights, - output_tensor, - N, - K, - lora_indices_tensor, - scaling, - inputs.stride(0), - inputs.stride(1), - lora_a_weights.stride(0), - lora_a_weights.stride(1), - lora_a_weights.stride(2), - output_tensor.stride(0), - output_tensor.stride(1), - BLOCK_N=BLOCK_N, - **config, - ) - return - - -def bgmv_shrink_fake( - inputs: torch.Tensor, - lora_a_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0, -) -> None: - return - - -try: - direct_register_custom_op( - op_name="bgmv_shrink", - op_func=_bgmv_shrink, - mutates_args=["output_tensor"], - fake_impl=bgmv_shrink_fake, - ) - bgmv_shrink = torch.ops.vllm.bgmv_shrink - -except AttributeError: - bgmv_shrink = _bgmv_shrink diff --git a/vllm/lora/ops/triton_ops/v1/v1_expand.py b/vllm/lora/ops/triton_ops/lora_expand.py similarity index 96% rename from vllm/lora/ops/triton_ops/v1/v1_expand.py rename to vllm/lora/ops/triton_ops/lora_expand.py index 20c7f8f4c7ff..b47e491ad7ed 100644 --- a/vllm/lora/ops/triton_ops/v1/v1_expand.py +++ b/vllm/lora/ops/triton_ops/lora_expand.py @@ -18,7 +18,7 @@ @triton.jit -def _v1_expand_kernel( +def _lora_expand_kernel( input_ptr, lora_ptr, out_ptr, @@ -125,7 +125,7 @@ def _v1_expand_kernel( @torch.inference_mode() -def _v1_expand( +def _lora_expand( inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] lora_b_weights: List[ torch.Tensor], # shape [num_lora, hidden_size, lora_rank] @@ -216,7 +216,7 @@ def _v1_expand( MAX_LORAS, ) - _v1_expand_kernel[grid]( + _lora_expand_kernel[grid]( inputs, lora_ptr_tensor, output_tensor, @@ -254,7 +254,7 @@ def _v1_expand( return -def _v1_expand_fake( +def _lora_expand_fake( inputs: torch.Tensor, lora_b_weights: List[torch.Tensor], output_tensor: torch.Tensor, @@ -271,12 +271,12 @@ def _v1_expand_fake( try: direct_register_custom_op( - op_name="v1_expand", - op_func=_v1_expand, + op_name="lora_expand", + op_func=_lora_expand, mutates_args=["output_tensor"], - fake_impl=_v1_expand_fake, + fake_impl=_lora_expand_fake, ) - v1_expand = torch.ops.vllm.v1_expand + lora_expand = torch.ops.vllm.lora_expand except AttributeError: - v1_expand = _v1_expand + lora_expand = _lora_expand diff --git a/vllm/lora/ops/triton_ops/v1/v1_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py similarity index 94% rename from vllm/lora/ops/triton_ops/v1/v1_kernel_metadata.py rename to vllm/lora/ops/triton_ops/lora_kernel_metadata.py index 57b4dd7a9027..2add1177e84c 100644 --- a/vllm/lora/ops/triton_ops/v1/v1_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py 
@@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ -V1 LoRA kernels metadata preparation utilities. +LoRA kernels metadata preparation utilities. """ from dataclasses import dataclass @@ -10,7 +10,7 @@ @dataclass -class V1KernelMeta: +class LoRAKernelMeta: token_lora_mapping: torch.Tensor token_indices_sorted_by_lora_ids: torch.Tensor active_lora_ids: torch.Tensor @@ -19,7 +19,7 @@ class V1KernelMeta: @staticmethod def make(max_loras: int, max_num_tokens: int, - device: Union[torch.device, str]) -> "V1KernelMeta": + device: Union[torch.device, str]) -> "LoRAKernelMeta": token_lora_mapping = torch.empty(max_num_tokens, dtype=torch.int32, @@ -47,7 +47,7 @@ def make(max_loras: int, max_num_tokens: int, lora_token_start_loc = torch.zeros(max_loras + 2, dtype=torch.int32, device=device) - return V1KernelMeta( + return LoRAKernelMeta( token_lora_mapping=token_lora_mapping, token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, active_lora_ids=active_lora_ids, @@ -105,7 +105,7 @@ def meta_args( This function returns the kernel metadata required for the current forward pass execution of the kernel. The function returns all the metadata required by the kernel, in order, as a tuple, so it can be - unpacked directly during the v1_shrink/v1_expand function call. + unpacked directly during the lora_shrink/lora_expand function call. Args: token_nums (int): Number of input tokens in the current forward diff --git a/vllm/lora/ops/triton_ops/v1/v1_shrink.py b/vllm/lora/ops/triton_ops/lora_shrink.py similarity index 88% rename from vllm/lora/ops/triton_ops/v1/v1_shrink.py rename to vllm/lora/ops/triton_ops/lora_shrink.py index 39affd189220..a97c50c44f47 100644 --- a/vllm/lora/ops/triton_ops/v1/v1_shrink.py +++ b/vllm/lora/ops/triton_ops/lora_shrink.py @@ -18,15 +18,15 @@ @triton.jit -def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, - token_indices_sorted_by_lora_ids, num_tokens_per_lora, - lora_token_start_loc, lora_ids, scaling, input_d0_stride, - input_d1_stride, lora_d0_stride, lora_d1_stride, - lora_d2_stride, output_d0_stride, output_d1_stride, - output_d2_stride, BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, - SLICE_NUM: tl.constexpr): +def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, + token_indices_sorted_by_lora_ids, num_tokens_per_lora, + lora_token_start_loc, lora_ids, scaling, + input_d0_stride, input_d1_stride, lora_d0_stride, + lora_d1_stride, lora_d2_stride, output_d0_stride, + output_d1_stride, output_d2_stride, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr): cta_n_num = tl.cdiv(N, BLOCK_N) cta_m_num = tl.cdiv(M, BLOCK_M) @@ -96,7 +96,7 @@ def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, @torch.inference_mode() -def _v1_shrink( +def _lora_shrink( inputs: torch.Tensor, # shape [num_tokens, hidden_size] lora_a_weights: List[ torch.Tensor], # shape [num_loras, lora_rank, hidden_size] @@ -174,7 +174,7 @@ def _v1_shrink( MAX_LORAS, ) - _v1_shrink_kernel[grid]( + _lora_shrink_kernel[grid]( inputs, lora_ptr_tensor, output_tensor, @@ -209,7 +209,7 @@ def _v1_shrink( return -def _v1_shrink_fake( +def _lora_shrink_fake( inputs: torch.Tensor, lora_a_weights: List[torch.Tensor], output_tensor: torch.Tensor, @@ -225,12 +225,12 @@ def _v1_shrink_fake( try: direct_register_custom_op( - op_name="v1_shrink", - op_func=_v1_shrink, + op_name="lora_shrink", + op_func=_lora_shrink, 
mutates_args=["output_tensor"], - fake_impl=_v1_shrink_fake, + fake_impl=_lora_shrink_fake, ) - v1_shrink = torch.ops.vllm.v1_shrink + lora_shrink = torch.ops.vllm.lora_shrink except AttributeError: - v1_shrink = _v1_shrink + lora_shrink = _lora_shrink diff --git a/vllm/lora/ops/triton_ops/sgmv_expand.py b/vllm/lora/ops/triton_ops/sgmv_expand.py deleted file mode 100644 index 6aa3eafaba4c..000000000000 --- a/vllm/lora/ops/triton_ops/sgmv_expand.py +++ /dev/null @@ -1,249 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. -https://arxiv.org/abs/2310.18547 -""" - -from typing import List - -import torch -import triton -import triton.language as tl - -from vllm.utils import direct_register_custom_op - -from .kernel_utils import do_expand_kernel -from .utils import _get_lora_b_ptr - - -@triton.jit -def _sgmv_expand_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - b_seq_start_loc, - seq_lens, - lora_indices, - slice_start_loc, - input_d0_stride, - input_d1_stride, - input_d2_stride, # 1 - ls_d0_ptr, - ls_d1_ptr, - ls_d2_ptr, # 1 - output_d0_stride, - output_d1_stride, # 1 - output_hs_ptr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - ADD_INPUTS: tl.constexpr, - CAST_TYPE: tl.constexpr, - SLICE_NUM: tl.constexpr, - SAME_STRIDE: tl.constexpr): - """ - - Similar to the 'sgmv_expand' operator, but with an added parameter - 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator - might be that in the future, we could implement a fusion operator to - achieve the current functionality instead of having to call it multiple - times. - """ - pid = tl.program_id(axis=0) - cur_batch = tl.program_id(axis=1) - slice_id = tl.program_id(axis=2) - cta_n_num = tl.cdiv(N, BLOCK_N) - # When the output dimensions of each slice are the same,cur_n=N, otherwise - # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's - # qkv linear. 
- curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) - pid_m = pid // cta_n_num - pid_n = pid % cta_n_num - - M = tl.load(seq_lens + cur_batch) - if pid_m * BLOCK_M >= M: - return - if pid_n * BLOCK_N >= curr_N: - return - lora_index = tl.load(lora_indices + cur_batch) - if lora_index == -1: - return - - m_offset = tl.load(b_seq_start_loc + cur_batch) - - cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M)) - cta_m_offset = m_offset + (pid_m * BLOCK_M) - offset_m = tl.arange(0, BLOCK_M) - ram = cta_m_offset + tl.max_contiguous( - tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M) - do_expand_kernel( - pid_n, - lora_index, - slice_id, - input_ptr, - lora_ptr, - out_ptr, - curr_N, - K, - cta_m_len, - ram, # array identifying the rows of Input ptr to operate on - slice_start_loc, - # input ptr strides - input_d0_stride, - input_d1_stride, - input_d2_stride, - # lora ptr strides - ls_d0_ptr, - ls_d1_ptr, - ls_d2_ptr, - # out ptr strides - output_d0_stride, - output_d1_stride, - # constants - BLOCK_M, - BLOCK_N, - BLOCK_K, - SAME_STRIDE, - SLICE_NUM, - EVEN_K, - CAST_TYPE, - ADD_INPUTS, - ) - - -@torch.inference_mode() -def _sgmv_expand( - inputs: torch.Tensor, - lora_b_weights: List[torch.Tensor], - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - offset_start: int = 0, - add_inputs: bool = False, -) -> None: - """ - Args: - inputs (torch.Tensor): input tensor - lora_b_weights (List[torch.Tensor]): lora'b weight - output_tensor (torch.Tensor): output tensor - b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative - sequence lengths of the sequences in the batch, used to index - into sequence. E.g., if the sequence length is [4, 6], it is - [0, 4]. - seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence - length of the sequences in the batch. - lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index - corresponding to each batch. An index of -1 means no lora should be - applied. - batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences in the - batch. - token_nums (int): The token numbers in the batch. Used to verify if the - token numbers in the inputs matches the one in the metadata. - offset_start (int, optional): Offset start for output_tensor. - Defaults to 0. - add_inputs (bool, optional): Whether to add the input tensor to the - output tensor. Defaults to False. 
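The docstring above fully determines a reference semantics. A plain-PyTorch sketch of the single-slice case for comparison (shapes as documented; `offset_start` and multi-slice handling omitted for brevity):

import torch

def sgmv_expand_ref(inputs, lora_b_weight, output_tensor, b_seq_start_loc,
                    seq_len_tensor, lora_indices_tensor, add_inputs=False):
    # inputs: [1, num_tokens, rank]; lora_b_weight: [num_loras, hidden, rank]
    for b in range(seq_len_tensor.numel()):
        idx = int(lora_indices_tensor[b])
        if idx == -1:
            continue  # -1 means: no LoRA applied to this sequence
        s = int(b_seq_start_loc[b])
        e = s + int(seq_len_tensor[b])
        out = inputs[0, s:e].to(lora_b_weight.dtype) @ lora_b_weight[idx].T
        if add_inputs:
            output_tensor[s:e] += out
        else:
            output_tensor[s:e] = out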
- """ - assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] - for weight in lora_b_weights: - assert weight.dtype in [torch.float16, torch.bfloat16] - - assert inputs.size(1) == token_nums - assert inputs.size(0) == len(lora_b_weights) - - assert b_seq_start_loc.size(0) == batches - assert lora_indices_tensor.size(0) == batches - assert output_tensor.is_contiguous() - (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, - lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor, - same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start, - b_seq_start_loc.device) - - # TODO tuning this config - K = lora_b_weights[0].shape[-1] # K= rank - - BLOCK_M = 64 - BLOCK_N = 128 - BLOCK_K = 16 - EVEN_K = K % BLOCK_K == 0 - ADD_INPUTS = add_inputs - CAST_TYPE = False - - if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [ - torch.float16, - torch.bfloat16, - ]: - CAST_TYPE = True - grid = ( - triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), - batches, - len(lora_b_weights), - ) - _sgmv_expand_kernel[grid]( - inputs, - lora_ptr_tensor, - output_tensor, - MAX_N, - K, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - slice_start_tensor, - inputs.stride(0), - inputs.stride(1), - inputs.stride(2), - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, - output_tensor.stride(0), - output_tensor.stride(1), - hidden_sizes_tensor, - BLOCK_M, - BLOCK_N, - BLOCK_K, - EVEN_K, - ADD_INPUTS, - CAST_TYPE, - len(lora_b_weights), - same_stride, - ) - return - - -def _sgmv_expand_fake( - inputs: torch.Tensor, - lora_b_weights: List[torch.Tensor], - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - offset_start: int = 0, - add_inputs: bool = False, -) -> None: - return - - -try: - direct_register_custom_op( - op_name="sgmv_expand", - op_func=_sgmv_expand, - mutates_args=["output_tensor"], - fake_impl=_sgmv_expand_fake, - ) - sgmv_expand = torch.ops.vllm.sgmv_expand - -except AttributeError: - sgmv_expand = _sgmv_expand diff --git a/vllm/lora/ops/triton_ops/sgmv_shrink.py b/vllm/lora/ops/triton_ops/sgmv_shrink.py deleted file mode 100644 index b8ed0b020f9a..000000000000 --- a/vllm/lora/ops/triton_ops/sgmv_shrink.py +++ /dev/null @@ -1,224 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. -https://arxiv.org/abs/2310.18547 -""" - -from typing import List - -import torch -import triton -import triton.language as tl - -from vllm.utils import direct_register_custom_op - -from .kernel_utils import do_shrink_kernel -from .utils import _get_lora_a_ptr - - -@triton.jit -def _sgmv_shrink_kernel( - input_ptr, - lora_ptr, #1-3 - out_ptr, - N, - K, - b_seq_start_loc, - seq_lens, - lora_indices, - scaling, - input_d0_stride, - input_d1_stride, # 1 - lora_d0_stride, - lora_d1_stride, - lora_d2_stride, # 1 - output_d0_stride, - output_d1_stride, - output_d2_stride, # 1 - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - SPLIT_K: tl.constexpr, - SLICE_NUM: tl.constexpr): - """ - The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. - The GEMM of Multi-LoRA can be considered as GroupGEMM. 
Additionally, - introducing SPLIT-K can improve performance - """ - pid = tl.program_id(axis=0) - pid_mix = tl.program_id(axis=1) - cur_batch = tl.program_id(axis=2) - cta_n_num = tl.cdiv(N, BLOCK_N) - pid_m = pid // cta_n_num - pid_n = pid % cta_n_num - if SLICE_NUM == 1: - slice_id: tl.constexpr = 0 - pid_sk = tl.program_id(axis=1) - else: - pid_mix = tl.program_id(axis=1) - slice_id = pid_mix // SPLIT_K - pid_sk = pid_mix % SPLIT_K - - M = tl.load(seq_lens + cur_batch) - if pid_m * BLOCK_M >= M: - return - lora_index = tl.load(lora_indices + cur_batch) - if lora_index == -1: - return - - m_offset = tl.load(b_seq_start_loc + cur_batch) - - cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M)) - cta_m_offset = m_offset + (pid_m * BLOCK_M) - offset_m = tl.arange(0, BLOCK_M) - ram = cta_m_offset + tl.max_contiguous( - tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M) - - do_shrink_kernel( - pid_n, - pid_sk, - slice_id, - lora_index, - input_ptr, - lora_ptr, - out_ptr, - N, - K, - cta_m_len, - ram, - # input strides - input_d0_stride, - input_d1_stride, - # lora strides - lora_d0_stride, - lora_d1_stride, - lora_d2_stride, - # output strides - output_d0_stride, - output_d1_stride, - output_d2_stride, - scaling, - BLOCK_M, - BLOCK_N, - BLOCK_K, - EVEN_K, - SPLIT_K, - SLICE_NUM) - - -@torch.inference_mode() -def _sgmv_shrink( - inputs: torch.Tensor, - lora_a_weights: List[torch.Tensor], - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - scaling: float, -) -> None: - """ - Args: - inputs (torch.Tensor): input tensor - lora_a_weights (List[torch.Tensor]): lora'a weight - output_tensor (torch.Tensor): output tensor - b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative - sequence lengths of the sequences in the batch, used to index - into sequence. E.g., if the sequence length is [4, 6], it is - [0, 4]. - seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence - length of the sequences in the batch. - lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index - corresponding to each batch. An index of -1 means no lora should be - applied. - batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences in the - batch. - token_nums (int): The token numbers in the batch. Used to verify if the - token numbers in the inputs matches the one in the metadata. - scaling (float): Scaling factor. 
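As with the expand op, the contract above pins down a reference. A sketch of the single-slice shrink, plus the split-K identity the kernel relies on (shapes as documented; fp32 accumulation assumed from the kernel's output dtype):

import torch

def sgmv_shrink_ref(inputs, lora_a_weight, output_tensor, b_seq_start_loc,
                    seq_len_tensor, lora_indices_tensor, scaling):
    # inputs: [num_tokens, hidden]; lora_a_weight: [num_loras, rank, hidden]
    # output_tensor: [num_tokens, rank], float32
    for b in range(seq_len_tensor.numel()):
        idx = int(lora_indices_tensor[b])
        if idx == -1:
            continue
        s = int(b_seq_start_loc[b])
        e = s + int(seq_len_tensor[b])
        output_tensor[s:e] = scaling * (
            inputs[s:e].float() @ lora_a_weight[idx].float().T)

# SPLIT-K works because the K axis is a pure reduction:
#   x @ w.T == sum over k-chunks of x[:, k0:k1] @ w[:, k0:k1].T
# so SPLIT_K programs compute partial products and accumulate in fp32.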
- """ - assert inputs.dtype == lora_a_weights[0].dtype - assert inputs.dtype in [torch.float16, torch.bfloat16] - for weight in lora_a_weights: - assert weight.dtype in [torch.float16, torch.bfloat16] - - assert inputs.size(0) == token_nums - assert inputs.size(1) == lora_a_weights[0].size(-1) - assert b_seq_start_loc.size(0) == batches - assert lora_indices_tensor.size(0) == batches - assert inputs.is_contiguous() - assert output_tensor.is_contiguous() - (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, - lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device) - # TODO tuning this config - N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank - BLOCK_M = 32 - BLOCK_N = 16 - BLOCK_K = 32 - SPLIT_K = 8 - EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 - grid = ( - triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), - SPLIT_K * len(lora_a_weights), - batches, - ) - _sgmv_shrink_kernel[grid]( - inputs, - lora_ptr_tensor, - output_tensor, - N, - K, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - scaling, - inputs.stride(0), - inputs.stride(1), - lora_strides_d0, - lora_strides_d1, - lora_strides_d2, - output_tensor.stride(0), - output_tensor.stride(1), - output_tensor.stride(2), - BLOCK_M, - BLOCK_N, - BLOCK_K, - EVEN_K, - SPLIT_K, - len(lora_a_weights), - ) - return - - -def sgmv_shrink_fake( - inputs: torch.Tensor, - lora_a_weights: List[torch.Tensor], - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - scaling: float, -) -> None: - return - - -try: - direct_register_custom_op( - op_name="sgmv_shrink", - op_func=_sgmv_shrink, - mutates_args=["output_tensor"], - fake_impl=sgmv_shrink_fake, - ) - sgmv_shrink = torch.ops.vllm.sgmv_shrink - -except AttributeError: - sgmv_shrink = _sgmv_shrink diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index b52a842cdaf9..f779bbccd31a 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -1,55 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -import functools from typing import Dict, List, Tuple import torch - -@functools.lru_cache -def _get_op_configs(op_type: str, batch: int, hidden_size: int): - # TODO: add optimal configurations - return None - - -def _check_divisibility(hidden_size: int): - # The bgmv_expand kernel requires that the hidden_size be divisible by - # the number below. - divisibility = [2, 4, 8, 16, 32, 64] - divisibility.sort(reverse=True) - for div in divisibility: - if hidden_size % div == 0: - return div - # hidden_size is an odd number - return 1 - - -def _get_default_config(op_type: str, batch: int, hidden_size: int): - if op_type == "expand": - return { - "BLOCK_N": 256, - "SPLIT_N": _check_divisibility(hidden_size), - "num_warps": 8 - } - else: - return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8} - - -def get_lora_op_configs(op_type: str, batch: int, - hidden_size: int) -> Dict[str, int]: - """Inspired by `fused_moe_kernel` - The return value will be a dictionary mapping an irregular grid of batch - sizes and hidden_size to configurations of the bgmv-related kernel. - NOTE: It currently only supports the default configuration. We plan to - generate optimal configurations for different hardware in the future using - scripts similar to `benchmark_moe.py`. 
- """ - config = _get_op_configs(op_type, batch, hidden_size) - if not config: - config = _get_default_config(op_type, batch, hidden_size) - return config - - _LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} _LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} diff --git a/vllm/lora/ops/triton_ops/v1/__init__.py b/vllm/lora/ops/triton_ops/v1/__init__.py deleted file mode 100644 index 1d2c46f4e9fa..000000000000 --- a/vllm/lora/ops/triton_ops/v1/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm.lora.ops.triton_ops.v1.v1_expand import v1_expand -from vllm.lora.ops.triton_ops.v1.v1_kernel_metadata import V1KernelMeta -from vllm.lora.ops.triton_ops.v1.v1_shrink import v1_shrink - -__all__ = [ - "v1_expand", - "v1_shrink", - "V1KernelMeta", -] \ No newline at end of file diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 19a94eea910c..eb6f5b1b488c 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -10,20 +10,12 @@ import torch -import vllm.envs as env from vllm.lora.layers import LoRAMapping from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - if env.VLLM_USE_V1: - from vllm.lora.ops.triton_ops.v1 import (V1KernelMeta, v1_expand, - v1_shrink) - else: - from vllm.lora.ops.triton_ops import bgmv_expand - from vllm.lora.ops.triton_ops import bgmv_expand_slice - from vllm.lora.ops.triton_ops import bgmv_shrink - from vllm.lora.ops.triton_ops import sgmv_expand - from vllm.lora.ops.triton_ops import sgmv_shrink + from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand, + lora_shrink) from .punica_base import PunicaWrapperBase @@ -32,57 +24,8 @@ from vllm.lora.models import LongContextLoRAContext -class V1KernelMixin: - - def _v1_make_metadata(self, max_loras: int, max_num_batched_tokens: int, - max_batches: int, device: Union[torch.device, str]): - self.token_mapping_v1_meta = V1KernelMeta.make(max_loras, - max_num_batched_tokens, - device=device) - self.prompt_mapping_v1_meta = V1KernelMeta.make(max_loras, - max_batches, - device=device) - - def _v1_prepare_metadata_tensors(self, token_lora_indices: torch.Tensor, - sampler_indices: torch.Tensor): - self.token_mapping_v1_meta.prepare_tensors(token_lora_indices) - self.prompt_mapping_v1_meta.prepare_tensors(sampler_indices) - - def _v1_apply_shrink( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: Tuple[torch.Tensor, ...], - scale: float, - ): - v1_shrink( - x, - w_t_all, - y, - *self.token_mapping_v1_meta.meta_args(x.size(0)), - scale, - ) - - def _v1_apply_expand( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: Tuple[torch.Tensor, ...], - offset_start: int, - add_inputs: bool, - ): - v1_expand( - x, - w_t_all, - y, - *self.token_mapping_v1_meta.meta_args(x.size(0)), - offset_start=offset_start, - add_inputs=add_inputs, - ) - - @final -class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin): +class PunicaWrapperGPU(PunicaWrapperBase): """ PunicaWrapperGPU is designed to manage and provide metadata for the punica kernel. 
The main function is to maintain the state information for @@ -96,9 +39,12 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self.max_loras = kwargs['max_loras'] - if env.VLLM_USE_V1: - self._v1_make_metadata(self.max_loras, max_num_batched_tokens, - max_batches, device) + self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras, + max_num_batched_tokens, + device=device) + self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras, + max_batches, + device=device) def update_metadata( self, @@ -110,83 +56,18 @@ def update_metadata( long_lora_context: Optional["LongContextLoRAContext"] = None, **kwargs): - if env.VLLM_USE_V1: - self.is_prefill = mapping.is_prefill - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) - self._v1_prepare_metadata_tensors(self.token_lora_indices, - self.sampler_indices) - else: - # Forward to base class update_metadata - PunicaWrapperBase.update_metadata(self, mapping, lora_index_to_id, - max_loras, vocab_size, - extra_vocab_size, - long_lora_context, **kwargs) - - def _apply_shrink_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: Tuple[torch.Tensor, ...], - scale: float, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_shrink( - x, - w_t_all, - y, - *self.prefill_metadata, - scale, - ) + self.is_prefill = mapping.is_prefill + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) - def _apply_shrink_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - - def _apply_expand_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: Tuple[torch.Tensor, ...], - offset_start: int, - add_inputs: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return - - sgmv_expand( - x, - w_t_all, - y, - *self.prefill_metadata, - offset_start=offset_start, - add_inputs=add_inputs, - ) + # Prepare cuda kernel metadata tensors + self.token_mapping_meta.prepare_tensors(self.token_lora_indices) + self.prompt_mapping_meta.prepare_tensors(self.sampler_indices) - def _apply_expand_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_inputs: bool, - ): - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_inputs) - - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, **kwargs): + def add_shrink(self, y: torch.Tensor, x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, + ...], scale: float, **kwargs): """ Performs GEMM for multiple slices of lora_a. 
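A shape sketch for the `y[i] += (x @ lora_a_stacked[i]) * scale` semantics spelled out below (sizes are illustrative; the weight layout is taken from the lora_shrink docstring):

import torch

num_tokens, hidden, rank, num_loras, num_slices = 16, 4096, 16, 4, 3
x = torch.randn(num_tokens, hidden, dtype=torch.float16, device="cuda")
lora_a_stacked = tuple(
    torch.randn(num_loras, rank, hidden, dtype=torch.float16, device="cuda")
    for _ in range(num_slices))
# shrink writes one fp32 plane per slice:
y = torch.zeros(num_slices, num_tokens, rank,
                dtype=torch.float32, device="cuda")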
When `is_prefill is` true, it indicates that it is currently the @@ -199,33 +80,24 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y[i] += (x @ lora_a_stacked[i]) * scale Args: - y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + y (torch.Tensor): Output tensors x (torch.Tensor): Input tensor lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ x = x.view(-1, x.shape[-1]) - - if env.VLLM_USE_V1: - self._v1_apply_shrink(y, x, lora_a_stacked, scale) # type: ignore - else: - if self.is_prefill: - # NOTE fused kernel - self._apply_shrink_prefill( - y, # type: ignore - x, - lora_a_stacked, - scale) - else: - # TODO fuse these kernels - for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink_decode(y[slice_idx], x, - lora_a_stacked[slice_idx], scale) + lora_shrink( + x, + lora_a_stacked, + y, + *self.token_mapping_meta.meta_args(x.size(0)), + scale, + ) def add_expand(self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_b_stacked: Tuple[torch.Tensor, ...], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], output_slices: Tuple[int, ...], @@ -244,7 +116,7 @@ def add_expand(self, Args: y (torch.Tensor): Output tensor. - x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + x (torch.Tensor): Input tensors lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): bias's weight @@ -259,37 +131,19 @@ def add_expand(self, self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked) - if env.VLLM_USE_V1: - # TODO (varun): Profile with add_inputs = False. i.e. move the - # addition out of the kernel - self._v1_apply_expand( - y, - x, # type: ignore - lora_b_stacked, - offset_start, - add_inputs=True) - else: - - if self.is_prefill: - # NOTE fused kernel - self._apply_expand_prefill( - y, - x, # type: ignore - lora_b_stacked, - offset_start, - add_inputs=True) - else: - # TODO fuse these kernels - for slice_idx in range(len(lora_b_stacked)): - self._apply_expand_decode( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_start, - output_slices[slice_idx], - add_inputs=add_inputs, - ) - offset_start += output_slices[slice_idx] + assert x.ndim == 3 + assert x.size(0) == len(output_slices) + num_tokens = x.size(1) # first dimension is the num slices + + lora_expand( + x, + lora_b_stacked, + y, + *self.token_mapping_meta.meta_args(num_tokens), + offset_start=offset_start, + add_inputs=True, + ) + y = y.view_as(y_org) def add_lora_embedding(self, @@ -311,24 +165,14 @@ def add_lora_embedding(self, add_inputs (bool): Default to True. """ - if env.VLLM_USE_V1: - self._v1_apply_expand(y, - x.unsqueeze(dim=0), (lora_b_stacked, ), - offset_start=0, - add_inputs=add_inputs) - else: - if self.is_prefill: - sgmv_expand( - x.unsqueeze(dim=0), - (lora_b_stacked, ), - y, - *self.prefill_metadata, - offset_start=0, - add_inputs=add_inputs, - ) - else: - bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices, - add_inputs) + lora_expand( + x.unsqueeze(dim=0), + (lora_b_stacked, ), + y, + *self.token_mapping_meta.meta_args(x.size(0)), + offset_start=0, + add_inputs=add_inputs, + ) def add_lora_linear(self, y: torch.Tensor, @@ -339,7 +183,7 @@ def add_lora_linear(self, scale: float, output_slices: Tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[torch.Tensor] = None, **kwargs) -> None: """ Applicable to linear-related lora. 
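The `buffer` parameter is now a single stacked tensor because shrink and expand chain through it directly. A hedged sketch of that data flow, modeled on the add_lora_logits hunk below (`meta` stands in for the wrapper's mapping metadata; the helper name is hypothetical):

import torch

from vllm.lora.ops.triton_ops import lora_expand, lora_shrink

def shrink_then_expand(x, lora_a_stacked, lora_b_stacked, y, meta, scale):
    num_slices, rank = len(lora_a_stacked), lora_a_stacked[0].size(-2)
    # fp32 intermediate, one plane per slice
    buffer = torch.zeros(num_slices, x.size(0), rank,
                         dtype=torch.float32, device=x.device)
    lora_shrink(x, lora_a_stacked, buffer,
                *meta.meta_args(x.size(0)), scale)
    lora_expand(buffer, lora_b_stacked, y,
                *meta.meta_args(x.size(0)), add_inputs=True)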
@@ -361,7 +205,7 @@ def add_lora_linear(self, lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (Tuple[int, ...]): Every slice's size. - buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + buffer (Optional[torch.Tensor]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) @@ -431,21 +275,11 @@ def add_lora_logits(self, dtype=torch.float32, device=x.device) - if env.VLLM_USE_V1: - v1_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0), - *self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale) - - v1_expand(buffer.unsqueeze(dim=0), [lora_b_stacked], - y, - *self.prompt_mapping_v1_meta.meta_args(buffer.size(0)), - add_inputs=True) - else: - - # V0 LogitsProcessorWithLoRA always using bgmv. - bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) + lora_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0), + *self.prompt_mapping_meta.meta_args(x.size(0)), scale) + + lora_expand(buffer.unsqueeze(dim=0), [lora_b_stacked], + y, + *self.prompt_mapping_meta.meta_args(buffer.size(0)), + add_inputs=True) y = y.view_as(y_org) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 2814f0fda7c5..a8a19e0e6206 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -62,9 +62,10 @@ def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...], if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - # Set is_prefill to True, so we always use the SGMV kernels. - # For cuda platforms, we have specialized triton kernels, and - # the cuda path ignores `is_prefill`. + # Set is_prefill to True, so we always use the SGMV kernels on + # non-cuda platforms. + # On cuda platforms we use the same kernels for prefill and + # decode and this flag is generally ignored. lora_mapping = LoRAMapping(token_lora_mapping, prompt_lora_mapping, is_prefill=True)
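For completeness, the V1 runner's side of the contract is just this mapping. A compact sketch of the constructor as used in the final hunk, with illustrative ids:

from vllm.lora.layers import LoRAMapping

token_lora_mapping = (0, 0, 0, 1, 1)  # per-token lora ids for a 2-seq batch
prompt_lora_mapping = (0, 1)          # per-sequence lora ids
# On CUDA the unified lora_shrink/lora_expand kernels serve prefill and
# decode alike, so is_prefill is effectively ignored there; other backends
# still branch on it to select the SGMV path.
lora_mapping = LoRAMapping(token_lora_mapping, prompt_lora_mapping,
                           is_prefill=True)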