From 5600ce12603726eb969079d89212f140fbc0406b Mon Sep 17 00:00:00 2001 From: Ozzie Moreno Date: Wed, 30 Apr 2025 11:58:05 -0700 Subject: [PATCH] [Bugfix] Adding maxnreg to lora expand/shrink kernel definition --- vllm/lora/ops/triton_ops/lora_expand_op.py | 3 ++- vllm/lora/ops/triton_ops/lora_shrink_op.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index eacc6fb46ebd..a41fde588a8e 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -46,7 +46,8 @@ def _lora_expand_kernel( ADD_INPUTS: tl.constexpr, CAST_TYPE: tl.constexpr, SLICE_NUM: tl.constexpr, - SAME_STRIDE: tl.constexpr): + SAME_STRIDE: tl.constexpr, + maxnreg: tl.constexpr): cta_n_num = tl.cdiv(N, BLOCK_N) cta_m_num = tl.cdiv(M, BLOCK_M) diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 82331939d859..60e0cb2d5770 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -26,7 +26,8 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, output_d1_stride, output_d2_stride, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, - SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr): + SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr, + maxnreg: tl.constexpr): cta_n_num = tl.cdiv(N, BLOCK_N) cta_m_num = tl.cdiv(M, BLOCK_M)