@@ -482,41 +482,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "                    Tensor page_table, float scale) -> ()");
   ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
 
-  // Mamba selective scan kernel
-  ops.def(
-      "selective_scan_fwd(Tensor! u, Tensor! delta,"
-      "Tensor! A, Tensor! B, Tensor! C,"
-      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
-      "bool delta_softplus,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "Tensor! ssm_states,"
-      "int pad_slot_id) -> ()");
-  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
-
-  ops.def(
-      "causal_conv1d_update(Tensor! x,"
-      "Tensor! conv_state,"
-      "Tensor! weight,"
-      "Tensor? bias_,"
-      "bool silu_activation,"
-      "Tensor? cache_seqlens_,"
-      "Tensor? conv_state_indices,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
-
-  ops.def(
-      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
-      "Tensor? bias_,"
-      "Tensor!? conv_states,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "bool silu_activation,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
-
   // Compute NVFP4 block quantized tensor.
   ops.def(
       "scaled_fp4_quant(Tensor! output, Tensor input,"
@@ -584,6 +549,41 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
+  // Mamba selective scan kernel
+  ops.def(
+      "selective_scan_fwd(Tensor! u, Tensor! delta,"
+      "Tensor! A, Tensor! B, Tensor! C,"
+      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
+      "bool delta_softplus,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "Tensor! ssm_states,"
+      "int pad_slot_id) -> ()");
+  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
+
+  ops.def(
+      "causal_conv1d_update(Tensor! x,"
+      "Tensor! conv_state,"
+      "Tensor! weight,"
+      "Tensor? bias_,"
+      "bool silu_activation,"
+      "Tensor? cache_seqlens_,"
+      "Tensor? conv_state_indices,"
+      "int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
+
+  ops.def(
+      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
+      "Tensor? bias_,"
+      "Tensor!? conv_states,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "bool silu_activation,"
+      "int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
+
 #ifndef USE_ROCM
   // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
   ops.def(