
Commit 0720d52

youkaichao authored and rasmith committed
[optimization] remove python function call for custom op (vllm-project#11750)
Signed-off-by: youkaichao <[email protected]>
1 parent 20bd63c commit 0720d52

File tree: 4 files changed (+15 −13 lines)
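The optimization removes a layer of Python indirection on the hot path: instead of routing each activation through a thin shim in vllm._custom_ops, call sites invoke the registered torch.ops._C kernel directly, or through a reference bound once at construction time. A minimal sketch of the cost being eliminated, using hypothetical pure-Python stand-ins (kernel, wrapper) rather than the real compiled ops:

import timeit

def kernel(out, x):      # stands in for the compiled torch.ops._C.silu_and_mul
    pass

def wrapper(out, x):     # stands in for the deleted vllm._custom_ops shim
    kernel(out, x)

direct = kernel          # new pattern: resolve the op once, then call it directly

print("via wrapper:", timeit.timeit(lambda: wrapper(None, None), number=10**6))
print("direct call:", timeit.timeit(lambda: direct(None, None), number=10**6))

The wrapper costs one extra Python frame per call; across thousands of layer invocations per forward pass, that overhead is what the commit title refers to.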

vllm/_custom_ops.py

Lines changed: 0 additions & 4 deletions

@@ -35,10 +35,6 @@ def register_fake(fn):


 # activation ops
-def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.silu_and_mul(out, x)
-
-
 def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
     torch.ops._C.gelu_and_mul(out, x)
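With the shim deleted, callers go through the registered C++ op directly. A sketch of the call convention, assuming vLLM's compiled _C extension is loaded and a CUDA device is available; the op writes its result into the preallocated out tensor in place:

import torch

x = torch.randn(8, 256, device="cuda")    # last dim is 2 * d
out = torch.empty(8, 128, device="cuda")  # result buffer, filled in place
torch.ops._C.silu_and_mul(out, x)         # out = SiLU(x[..., :d]) * x[..., d:]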

vllm/model_executor/layers/activation.py

Lines changed: 11 additions & 6 deletions

@@ -10,6 +10,7 @@
     get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.utils import LazyDict


@@ -58,27 +59,31 @@ class SiluAndMul(CustomOp):
     return: (num_tokens, d) or (batch_size, seq_len, d)
     """

+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike():
+            self.op = torch.ops._C.silu_and_mul
+        elif current_platform.is_xpu():
+            import intel_extension_for_pytorch as ipex
+            self.op = ipex.llm.functional.silu_and_mul
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
+        self.op(out, x)
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
+        self.op(out, x)
         return out
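Binding self.op in __init__ moves the platform branch (and the IPEX import) out of the per-token forward path, so each forward is a single bound-callable invocation. A stripped-down sketch of the pattern, with a pure-PyTorch function standing in for the real kernels; this is illustrative, not the actual vLLM class:

import torch
import torch.nn.functional as F

class SiluAndMulSketch(torch.nn.Module):
    """Illustrative stand-in for vLLM's SiluAndMul."""

    def __init__(self):
        super().__init__()
        # Resolve the kernel once at construction; forward() then makes one
        # call with no import, no attribute chain, and no branching.
        self.op = self._native  # swap in torch.ops._C.silu_and_mul when available

    @staticmethod
    def _native(out: torch.Tensor, x: torch.Tensor) -> None:
        d = x.shape[-1] // 2
        out.copy_(F.silu(x[..., :d]) * x[..., d:])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out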

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,6 @@

 import torch

-from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size, try_get_optimal_moe_config)
 from vllm.scalar_type import scalar_types
@@ -301,7 +300,8 @@ def fused_marlin_moe(
         False,
     )

-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
+    torch.ops._C.silu_and_mul(intermediate_cache2,
+                              intermediate_cache1.view(-1, 2 * N))

     intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
         intermediate_cache2,
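Since that was this file's only remaining use of the vllm._custom_ops alias, the from vllm import _custom_ops as ops import at the top of fused_marlin_moe.py is dropped in the same change. fused_moe.py (below) keeps its import, presumably because other call sites in that file still go through it.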

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 1 deletion

@@ -753,7 +753,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                        use_int8_w8a16=use_int8_w8a16,
                        block_shape=block_shape)

-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+    torch.ops._C.silu_and_mul(intermediate_cache2,
+                              intermediate_cache1.view(-1, N))

     invoke_fused_moe_kernel(intermediate_cache2,
                            w2,
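Both MoE call sites follow the same shape contract: the first grouped GEMM's output is viewed as a 2-D (tokens, 2 * d) tensor and the kernel writes a (tokens, d) result into the preallocated intermediate cache. A quick check with the PyTorch-native equivalent, using illustrative sizes:

import torch
import torch.nn.functional as F

cache1 = torch.randn(16, 256)                       # first GEMM output, 2 * d wide
d = cache1.shape[-1] // 2
cache2 = F.silu(cache1[..., :d]) * cache1[..., d:]  # what the fused kernel computes
assert cache2.shape == (16, 128)                    # the real op writes this in place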
