@@ -16,24 +16,33 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
1616 f(in_T, out_T, W_T, narrow, 512 ) \
1717 f(in_T, out_T, W_T, narrow, 640 ) \
1818 f(in_T, out_T, W_T, narrow, 768 ) \
19+ f(in_T, out_T, W_T, narrow, 896 ) \
1920 f(in_T, out_T, W_T, narrow, 1024 ) \
2021 f(in_T, out_T, W_T, narrow, 1152 ) \
22+ f(in_T, out_T, W_T, narrow, 1216 ) \
2123 f(in_T, out_T, W_T, narrow, 1280 ) \
2224 f(in_T, out_T, W_T, narrow, 1536 ) \
2325 f(in_T, out_T, W_T, narrow, 1664 ) \
2426 f(in_T, out_T, W_T, narrow, 1728 ) \
2527 f(in_T, out_T, W_T, narrow, 1792 ) \
2628 f(in_T, out_T, W_T, narrow, 2048 ) \
29+ f(in_T, out_T, W_T, narrow, 2240 ) \
2730 f(in_T, out_T, W_T, narrow, 2304 ) \
31+ f(in_T, out_T, W_T, narrow, 2368 ) \
32+ f(in_T, out_T, W_T, narrow, 2432 ) \
2833 f(in_T, out_T, W_T, narrow, 2560 ) \
2934 f(in_T, out_T, W_T, narrow, 2752 ) \
3035 f(in_T, out_T, W_T, narrow, 2816 ) \
3136 f(in_T, out_T, W_T, narrow, 3072 ) \
3237 f(in_T, out_T, W_T, narrow, 3328 ) \
3338 f(in_T, out_T, W_T, narrow, 3456 ) \
3439 f(in_T, out_T, W_T, narrow, 3584 ) \
40+ f(in_T, out_T, W_T, narrow, 3712 ) \
3541 f(in_T, out_T, W_T, narrow, 4096 ) \
42+ f(in_T, out_T, W_T, narrow, 4480 ) \
3643 f(in_T, out_T, W_T, narrow, 4608 ) \
44+ f(in_T, out_T, W_T, narrow, 4736 ) \
45+ f(in_T, out_T, W_T, narrow, 4864 ) \
3746 f(in_T, out_T, W_T, narrow, 5120 ) \
3847 f(in_T, out_T, W_T, narrow, 5504 ) \
3948 f(in_T, out_T, W_T, narrow, 5632 ) \
@@ -43,24 +52,32 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
4352 f(in_T, out_T, W_T, narrow, 6848 ) \
4453 f(in_T, out_T, W_T, narrow, 6912 ) \
4554 f(in_T, out_T, W_T, narrow, 7168 ) \
55+ f(in_T, out_T, W_T, narrow, 7424 ) \
4656 f(in_T, out_T, W_T, narrow, 8192 ) \
57+ f(in_T, out_T, W_T, narrow, 8960 ) \
4758 f(in_T, out_T, W_T, narrow, 9216 ) \
59+ f(in_T, out_T, W_T, narrow, 9472 ) \
4860 f(in_T, out_T, W_T, narrow, 10240 ) \
4961 f(in_T, out_T, W_T, narrow, 11008 ) \
5062 f(in_T, out_T, W_T, narrow, 11264 ) \
5163 f(in_T, out_T, W_T, narrow, 12288 ) \
5264 f(in_T, out_T, W_T, narrow, 13696 ) \
5365 f(in_T, out_T, W_T, narrow, 13824 ) \
5466 f(in_T, out_T, W_T, narrow, 14336 ) \
67+ f(in_T, out_T, W_T, narrow, 14784 ) \
68+ f(in_T, out_T, W_T, narrow, 14848 ) \
5569 f(in_T, out_T, W_T, narrow, 15360 ) \
5670 f(in_T, out_T, W_T, narrow, 16384 ) \
71+ f(in_T, out_T, W_T, narrow, 18944 ) \
5772 f(in_T, out_T, W_T, narrow, 20480 ) \
5873 f(in_T, out_T, W_T, narrow, 22016 ) \
5974 f(in_T, out_T, W_T, narrow, 22528 ) \
6075 f(in_T, out_T, W_T, narrow, 24576 ) \
6176 f(in_T, out_T, W_T, narrow, 27392 ) \
6277 f(in_T, out_T, W_T, narrow, 27648 ) \
6378 f(in_T, out_T, W_T, narrow, 28672 ) \
79+ f(in_T, out_T, W_T, narrow, 29568 ) \
80+ f(in_T, out_T, W_T, narrow, 29696 ) \
6481 f(in_T, out_T, W_T, narrow, 32000 ) \
6582 f(in_T, out_T, W_T, narrow, 32256 ) \
6683 f(in_T, out_T, W_T, narrow, 32512 ) \
@@ -85,34 +102,43 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
85102// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
86103// and vllm/tests/lora/test_punica.py
87104
88- // Used for defining kernels going from the variety of
105+ // Used for defining kernels going from the variety of
89106// dim in to the narrow dim out
90- // Using it for the fully sharded column
107+ // Using it for the fully sharded column
91108 // parallel LoRA A which splits the rank dim
92109#define FOR_INST_BGMV_NARROW (f, in_T, out_T, W_T, narrow ) \
93110 f (in_T, out_T, W_T, 128 , narrow) \
94111 f(in_T, out_T, W_T, 256 , narrow) \
95112 f(in_T, out_T, W_T, 512 , narrow) \
96113 f(in_T, out_T, W_T, 640 , narrow) \
97114 f(in_T, out_T, W_T, 768 , narrow) \
115+ f(in_T, out_T, W_T, 896 , narrow) \
98116 f(in_T, out_T, W_T, 1024 , narrow) \
99117 f(in_T, out_T, W_T, 1152 , narrow) \
118+ f(in_T, out_T, W_T, 1216 , narrow) \
100119 f(in_T, out_T, W_T, 1280 , narrow) \
101120 f(in_T, out_T, W_T, 1536 , narrow) \
102121 f(in_T, out_T, W_T, 1664 , narrow) \
103122 f(in_T, out_T, W_T, 1728 , narrow) \
104123 f(in_T, out_T, W_T, 1792 , narrow) \
105124 f(in_T, out_T, W_T, 2048 , narrow) \
125+ f(in_T, out_T, W_T, 2240 , narrow) \
106126 f(in_T, out_T, W_T, 2304 , narrow) \
127+ f(in_T, out_T, W_T, 2368 , narrow) \
128+ f(in_T, out_T, W_T, 2432 , narrow) \
107129 f(in_T, out_T, W_T, 2560 , narrow) \
108130 f(in_T, out_T, W_T, 2752 , narrow) \
109131 f(in_T, out_T, W_T, 2816 , narrow) \
110132 f(in_T, out_T, W_T, 3072 , narrow) \
111133 f(in_T, out_T, W_T, 3328 , narrow) \
112134 f(in_T, out_T, W_T, 3456 , narrow) \
113135 f(in_T, out_T, W_T, 3584 , narrow) \
136+ f(in_T, out_T, W_T, 3712 , narrow) \
114137 f(in_T, out_T, W_T, 4096 , narrow) \
138+ f(in_T, out_T, W_T, 4480 , narrow) \
115139 f(in_T, out_T, W_T, 4608 , narrow) \
140+ f(in_T, out_T, W_T, 4736 , narrow) \
141+ f(in_T, out_T, W_T, 4864 , narrow) \
116142 f(in_T, out_T, W_T, 5120 , narrow) \
117143 f(in_T, out_T, W_T, 5504 , narrow) \
118144 f(in_T, out_T, W_T, 5632 , narrow) \
@@ -122,24 +148,32 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
122148 f(in_T, out_T, W_T, 6848 , narrow) \
123149 f(in_T, out_T, W_T, 6912 , narrow) \
124150 f(in_T, out_T, W_T, 7168 , narrow) \
151+ f(in_T, out_T, W_T, 7424 , narrow) \
125152 f(in_T, out_T, W_T, 8192 , narrow) \
153+ f(in_T, out_T, W_T, 8960 , narrow) \
126154 f(in_T, out_T, W_T, 9216 , narrow) \
155+ f(in_T, out_T, W_T, 9472 , narrow) \
127156 f(in_T, out_T, W_T, 10240 , narrow) \
128157 f(in_T, out_T, W_T, 11008 , narrow) \
129158 f(in_T, out_T, W_T, 11264 , narrow) \
130159 f(in_T, out_T, W_T, 12288 , narrow) \
131160 f(in_T, out_T, W_T, 13696 , narrow) \
132161 f(in_T, out_T, W_T, 13824 , narrow) \
133162 f(in_T, out_T, W_T, 14336 , narrow) \
163+ f(in_T, out_T, W_T, 14784 , narrow) \
164+ f(in_T, out_T, W_T, 14848 , narrow) \
134165 f(in_T, out_T, W_T, 15360 , narrow) \
135166 f(in_T, out_T, W_T, 16384 , narrow) \
167+ f(in_T, out_T, W_T, 18944 , narrow) \
136168 f(in_T, out_T, W_T, 20480 , narrow) \
137169 f(in_T, out_T, W_T, 22016 , narrow) \
138170 f(in_T, out_T, W_T, 22528 , narrow) \
139171 f(in_T, out_T, W_T, 24576 , narrow) \
140172 f(in_T, out_T, W_T, 27392 , narrow) \
141173 f(in_T, out_T, W_T, 27648 , narrow) \
142174 f(in_T, out_T, W_T, 28672 , narrow) \
175+ f(in_T, out_T, W_T, 29568 , narrow) \
176+ f(in_T, out_T, W_T, 29696 , narrow) \
143177 f(in_T, out_T, W_T, 32000 , narrow) \
144178 f(in_T, out_T, W_T, 32256 , narrow) \
145179 f(in_T, out_T, W_T, 32512 , narrow) \
0 commit comments