#include "core/registration.h"

#include <torch/library.h>
+#include <torch/version.h>

// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
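// For instance (an illustrative sketch, not code from this commit): an op
// "foo(Tensor a) -> Tensor" would pair with a "foo_meta" of identical
// signature that only allocates an output of the right shape and dtype
// (e.g. via torch::empty_like), letting torch.compile trace shapes without
// launching the CUDA kernel.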

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops
+  //
+
+  // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
+  // to override this for many GEMMs with the following tag. Otherwise,
+  // torch.compile will force all input tensors to be contiguous(), which
+  // will break many custom ops that require column-major weight matrices.
+  // TODO: remove this for PyTorch 2.8, when the default is planned to switch
+  // to match exact eager-mode strides.
+  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
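+  //
+  // Illustrative sketch (editorial, not part of this change set): tagging a
+  // def such as
+  //   ops.def("my_gemm(Tensor a, Tensor b) -> Tensor", {stride_tag});
+  // makes torch.compile hand the kernel `a` and `b` with the strides they
+  // had in eager mode (e.g. a column-major `b` produced via .t()), instead
+  // of first forcing every argument through contiguous().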

  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
  ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
@@ -163,25 +173,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def(
      "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
      "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
-      "-> Tensor");
+      "-> Tensor",
+      {stride_tag});
  ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);

  // Decompression method for AQLM.
  ops.def(
      "aqlm_dequant(Tensor codes, Tensor codebooks, "
-      "int[] codebook_partition_sizes) -> Tensor");
+      "int[] codebook_partition_sizes) -> Tensor",
+      {stride_tag});
  ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);

  // Quantized GEMM for AWQ.
  ops.def(
      "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor",
+      {stride_tag});
  ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

  // Dequantization for AWQ.
  ops.def(
      "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
+      {stride_tag});
  ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

  // Note about marlin kernel 'workspace' arguments:
@@ -202,15 +216,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def(
      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
      "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
-      "Tensor");
+      "Tensor",
+      {stride_tag});
  // conditionally compiled so impl in source file

  // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
  ops.def(
      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
      "Tensor b_scales, Tensor workspace, "
      "int b_q_type, "
-      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
+      {stride_tag});
  // conditionally compiled so impl in source file

  // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -236,7 +252,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
236252 " Tensor? channel_scales,"
237253 " Tensor? token_scales,"
238254 " str? schedule"
239- " ) -> Tensor" );
255+ " ) -> Tensor" ,
256+ {stride_tag});
240257 ops.def (
241258 " machete_prepack_B("
242259 " Tensor B,"
@@ -255,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
255272 " Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
256273 " int b_q_type, "
257274 " SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
258- " bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor" );
275+ " bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor" ,
276+ {stride_tag});
259277 // conditionally compiled so impl registration is in source file
260278
261279 // gptq_marlin repack from GPTQ.
@@ -291,30 +309,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def(
      "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
      "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
-      "SymInt size_k) -> Tensor");
+      "SymInt size_k) -> Tensor",
+      {stride_tag});
  // conditionally compiled so impl registration is in source file

  // marlin_qqq_gemm for QQQ.
  ops.def(
      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
      "Tensor! workspace, SymInt size_m, SymInt size_n, "
-      "SymInt size_k) -> Tensor");
+      "SymInt size_k) -> Tensor",
+      {stride_tag});
  // conditionally compiled so impl registration is in source file

  // CUTLASS nvfp4 block scaled GEMM
  ops.def(
      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
      "Tensor block_scale_a, Tensor block_scale_b,"
-      "Tensor alpha) -> ()");
+      "Tensor alpha) -> ()",
+      {stride_tag});
  ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);

  // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  // quantization, as well as bias
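  // (Shape sketch, an editorial assumption rather than text from this diff:
  //  for out[M,N] = a[M,K] @ b[K,N], a_scales holds one element for
  //  per-tensor or M elements for per-row scaling, and b_scales holds one
  //  element or N elements for per-column scaling.)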
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      "Tensor b, Tensor a_scales,"
-      "Tensor b_scales, Tensor? bias) -> ()");
+      "Tensor b_scales, Tensor? bias) -> ()",
+      {stride_tag});
  ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

  // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@@ -323,7 +345,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
323345 " cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
324346 " Tensor b, Tensor a_scales,"
325347 " Tensor b_scales, Tensor azp_adj,"
326- " Tensor? azp, Tensor? bias) -> ()" );
348+ " Tensor? azp, Tensor? bias) -> ()" ,
349+ {stride_tag});
327350 ops.impl (" cutlass_scaled_mm_azp" , torch::kCUDA , &cutlass_scaled_mm_azp);
328351
329352 // Check if cutlass scaled_mm is supported for CUDA devices of the given
@@ -351,7 +374,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
351374 " cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
352375 " Tensor bt_nzs,"
353376 " Tensor bt_meta, Tensor a_scales,"
354- " Tensor b_scales, Tensor? bias) -> ()" );
377+ " Tensor b_scales, Tensor? bias) -> ()" ,
378+ {stride_tag});
355379 ops.impl (" cutlass_scaled_sparse_mm" , torch::kCUDA , &cutlass_scaled_sparse_mm);
356380
357381 // CUTLASS sparse matrix compressor
@@ -407,7 +431,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def(
      "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
      "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
-      "-> Tensor");
+      "-> Tensor",
+      {stride_tag});
  ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

  // Post processing for GPTQ.