 import torch.nn as nn
 import torch.nn.functional as F
 
-from vllm import _custom_ops as ops
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
 
-class SiluAndMul(nn.Module):
+class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.
 
     The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -23,20 +23,22 @@ class SiluAndMul(nn.Module):
         return: (num_tokens, d) or (batch_size, seq_len, d)
     """
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         ops.silu_and_mul(out, x)
         return out
 
 
-class GeluAndMul(nn.Module):
+class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
 
     The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -52,12 +54,14 @@ def __init__(self, approximate: str = "none"):
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -71,28 +75,32 @@ def extra_repr(self) -> str:
         return f'approximate={repr(self.approximate)}'
 
 
-class NewGELU(nn.Module):
+class NewGELU(CustomOp):
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
         return 0.5 * x * (1.0 + torch.tanh(c *
                                            (x + 0.044715 * torch.pow(x, 3.0))))
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         out = torch.empty_like(x)
         ops.gelu_new(out, x)
         return out
 
 
-class FastGELU(nn.Module):
+class FastGELU(CustomOp):
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         out = torch.empty_like(x)
         ops.gelu_fast(out, x)
         return out
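
For context on the pattern this diff moves to: each activation now exposes a forward_native reference path and a forward_cuda path that imports the fused kernels lazily via "from vllm import _custom_ops as ops", with dispatch handled by the CustomOp base class. The sketch below illustrates one way such a device-based dispatch could look; SketchCustomOp and SketchSiluAndMul are hypothetical names, and the dispatch rule shown is an assumption for illustration, not vLLM's actual CustomOp implementation.

# Minimal, self-contained sketch of a CustomOp-style dispatch (assumptions noted above).
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchCustomOp(nn.Module):
    """Hypothetical base class: route forward() to a backend-specific method."""

    def forward(self, *args, **kwargs):
        # Assumed dispatch rule for this sketch: prefer the CUDA path when a GPU
        # is available, otherwise fall back to the PyTorch-native reference.
        if torch.cuda.is_available():
            return self.forward_cuda(*args, **kwargs)
        return self.forward_native(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # No fused kernel wired up in this sketch, so reuse the native path.
        return self.forward_native(*args, **kwargs)


class SketchSiluAndMul(SketchCustomOp):
    """SwiGLU: silu(x[..., :d]) * x[..., d:] with d = x.shape[-1] // 2."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]


if __name__ == "__main__":
    x = torch.randn(2, 8, 128)
    layer = SketchSiluAndMul()
    out = layer(x)                      # goes through the dispatching forward()
    ref = layer.forward_native(x)       # reference path, as in the diff's docstrings
    assert out.shape == (2, 8, 64)
    torch.testing.assert_close(out, ref)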