11import contextlib
22import functools
3- from typing import List , Optional , Tuple , Union
3+ from typing import TYPE_CHECKING , List , Optional , Tuple , Union
44
55import torch
6+ import torch .library
67
78import vllm .envs as envs
89from vllm ._core_ext import ScalarType
2526 import vllm ._moe_C # noqa: F401
2627 supports_moe_ops = True
2728
29+ if TYPE_CHECKING :
30+
31+ def register_fake (fn ):
32+ return lambda name : fn
33+ else :
34+ try :
35+ from torch .library import register_fake
36+ except ImportError :
37+ from torch .library import impl_abstract as register_fake
38+
2839
2940def hint_on_error (fn ):
3041
@@ -266,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
266277
267278if hasattr (torch .ops ._C , "gptq_gemm" ):
268279
269- @torch . library . register_fake ("_C::gptq_gemm" )
280+ @register_fake ("_C::gptq_gemm" )
270281 def _gptq_gemm_fake (a : torch .Tensor , b_q_weight : torch .Tensor ,
271282 b_gptq_qzeros : torch .Tensor ,
272283 b_gptq_scales : torch .Tensor , b_g_idx : torch .Tensor ,
@@ -301,15 +312,15 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
301312
302313if hasattr (torch .ops ._C , "gptq_marlin_24_gemm" ):
303314
@register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_meta: torch.Tensor, b_scales: torch.Tensor,
                              workspace: torch.Tensor,
                              b_q_type: ScalarType, size_m: int,
                              size_n: int, size_k: int) -> torch.Tensor:
    """Fake (meta) impl: the GEMM output is (size_m, size_n), matching `a`'s
    dtype and device, with no real computation performed."""
    out_shape = (size_m, size_n)
    return torch.empty(out_shape, dtype=a.dtype, device=a.device)
311322
312- @torch . library . register_fake ("_C::gptq_marlin_gemm" )
323+ @register_fake ("_C::gptq_marlin_gemm" )
313324 def _gptq_marlin_gemm_fake (a : torch .Tensor ,
314325 b_q_weight : torch .Tensor ,
315326 b_scales : torch .Tensor ,
@@ -326,12 +337,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
326337 use_fp32_reduce : bool = False ) -> torch .Tensor :
327338 return torch .empty ((size_m , size_n ), device = a .device , dtype = a .dtype )
328339
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
                          n: int) -> torch.Tensor:
    """Fake (meta) impl: dequantized weights are an fp16 (m, n) tensor on
    `W`'s device; no actual dequantization happens."""
    out = torch.empty((m, n), device=W.device, dtype=torch.float16)
    return out
333344
334- @torch . library . register_fake ("_C::ggml_mul_mat_vec_a8" )
345+ @register_fake ("_C::ggml_mul_mat_vec_a8" )
335346 def _ggml_mul_mat_vec_a8_fake (
336347 W : torch .Tensor ,
337348 X : torch .Tensor ,
@@ -340,7 +351,7 @@ def _ggml_mul_mat_vec_a8_fake(
340351 ) -> torch .Tensor :
341352 return torch .empty ((1 , row ), dtype = torch .float16 , device = W .device )
342353
343- @torch . library . register_fake ("_C::ggml_mul_mat_a8" )
354+ @register_fake ("_C::ggml_mul_mat_a8" )
344355 def _ggml_mul_mat_a8_fake (
345356 W : torch .Tensor ,
346357 X : torch .Tensor ,
@@ -350,7 +361,7 @@ def _ggml_mul_mat_a8_fake(
350361 batch = X .size (0 )
351362 return torch .empty ((batch , row ), dtype = torch .float16 , device = W .device )
352363
353- @torch . library . register_fake ("_C::marlin_qqq_gemm" )
364+ @register_fake ("_C::marlin_qqq_gemm" )
354365 def _marlin_qqq_gemm_fake (a : torch .Tensor , b_q_weight : torch .Tensor ,
355366 s_tok : torch .Tensor , s_ch : torch .Tensor ,
356367 s_group : torch .Tensor , workspace : torch .Tensor ,
@@ -360,7 +371,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
360371 dtype = torch .float16 ,
361372 device = a .device )
362373
363- @torch . library . register_fake ("_C::marlin_gemm" )
374+ @register_fake ("_C::marlin_gemm" )
364375 def _marlin_gemm_fake (a : torch .Tensor , b_q_weight : torch .Tensor ,
365376 b_scales : torch .Tensor , workspace : torch .Tensor ,
366377 size_m : int , size_n : int ,
@@ -369,7 +380,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
369380 dtype = torch .float16 ,
370381 device = a .device )
371382
372- @torch . library . register_fake ("_C::awq_dequantize" )
383+ @register_fake ("_C::awq_dequantize" )
373384 def _awq_dequantize_fake (qweight : torch .Tensor , scales : torch .Tensor ,
374385 zeros : torch .Tensor , split_k_iters : int , thx : int ,
375386 thy : int ) -> torch .Tensor :
@@ -380,7 +391,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
380391 dtype = scales .dtype ,
381392 device = scales .device )
382393
383- @torch . library . register_fake ("_C::awq_gemm" )
394+ @register_fake ("_C::awq_gemm" )
384395 def _awq_gemm_fake (input : torch .Tensor , qweight : torch .Tensor ,
385396 qzeros : torch .Tensor , scales : torch .Tensor ,
386397 split_k_iters : int ) -> torch .Tensor :
@@ -389,7 +400,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
389400 dtype = input .dtype ,
390401 device = input .device ).sum (0 )
391402
392- @torch . library . register_fake ("_C::aqlm_gemm" )
403+ @register_fake ("_C::aqlm_gemm" )
393404 def _aqlm_gemm_fake (input : torch .Tensor , codes : torch .Tensor ,
394405 codebooks : torch .Tensor , scales : torch .Tensor ,
395406 codebook_partition_sizes : List [int ],
@@ -405,7 +416,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
405416 output_sizes .append (- 1 )
406417 return flat_output .reshape (tuple (output_sizes ))
407418
408- @torch . library . register_fake ("_C::aqlm_dequant" )
419+ @register_fake ("_C::aqlm_dequant" )
409420 def _aqlm_dequant_fake (
410421 codes : torch .Tensor , codebooks : torch .Tensor ,
411422 codebook_partition_sizes : List [int ]) -> torch .Tensor :
@@ -415,14 +426,14 @@ def _aqlm_dequant_fake(
415426 dtype = codebooks .dtype ,
416427 device = codebooks .device )
417428
@register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          num_bits: int, size_m: int, size_n: int,
                          size_k: int) -> torch.Tensor:
    """Fake (meta) impl: GEMM result is (size_m, size_n) in `a`'s dtype on
    `a`'s device."""
    return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
424435
425- @torch . library . register_fake ("_C::machete_gemm" )
436+ @register_fake ("_C::machete_gemm" )
426437 def machete_gemm_fake (
427438 a : torch .Tensor ,
428439 # Should be the tensor returned by machete_prepack_B
@@ -440,13 +451,13 @@ def machete_gemm_fake(
440451 n = b_q .size (1 )
441452 return torch .empty ((m , n ), device = a .device , dtype = a .dtype )
442453
@register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                           b_type: ScalarType) -> torch.Tensor:
    """Fake (meta) impl: the prepacked B tensor has the same shape and dtype
    as the input, in contiguous memory layout."""
    packed = torch.empty_like(b_q_weight,
                              memory_format=torch.contiguous_format)
    return packed
448459
449- @torch . library . register_fake ("_C::causal_conv1d_fwd" )
460+ @register_fake ("_C::causal_conv1d_fwd" )
450461 def causal_conv1d_fwd_fake (x : torch .Tensor , weight : torch .Tensor ,
451462 bias_ : Optional [torch .Tensor ],
452463 conv_states : Optional [torch .Tensor ],
@@ -456,15 +467,15 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
456467 silu_activation : bool ) -> torch .Tensor :
457468 return torch .empty_like (x )
458469
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
        x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
        bias_: Optional[torch.Tensor], silu_activation: bool,
        cache_seqlens: Optional[torch.Tensor],
        conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake (meta) impl: the update produces a tensor shaped exactly like the
    input `x` (same dtype and device)."""
    return torch.empty_like(x)
466477
467- @torch . library . register_fake ("_C::selective_scan_fwd" )
478+ @register_fake ("_C::selective_scan_fwd" )
468479 def selective_scan_fwd_fake (u : torch .Tensor , delta : torch .Tensor ,
469480 A : torch .Tensor , B : torch .Tensor ,
470481 C : torch .Tensor , D_ : Optional [torch .Tensor ],
@@ -639,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
639650
if hasattr(torch.ops._C, "permute_cols"):

    @register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor,
                           perm: torch.Tensor) -> torch.Tensor:
        """Fake (meta) impl: permuting columns preserves `a`'s shape, dtype
        and device."""
        return torch.empty_like(a)
@@ -837,7 +848,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
837848
838849if supports_moe_ops and hasattr (torch .ops ._moe_C , "marlin_gemm_moe" ):
839850
840- @torch . library . register_fake ("_moe_C::marlin_gemm_moe" )
851+ @register_fake ("_moe_C::marlin_gemm_moe" )
841852 def marlin_gemm_moe_fake (a : torch .Tensor , b_q_weights : torch .Tensor ,
842853 sorted_ids : torch .Tensor ,
843854 topk_weights : torch .Tensor ,
0 commit comments