Commit 014e9bc

WoosukKwon authored and joerunde committed
[Misc] Add CustomOp interface for device portability (vllm-project#5255)
1 parent 944283d commit 014e9bc

File tree

7 files changed: +100 -27 lines

tests/kernels/test_activation.py

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ def test_act_and_mul(
     elif activation == "gelu_tanh":
         layer = GeluAndMul(approximate="tanh")
     out = layer(x)
-    ref_out = layer._forward(x)
+    ref_out = layer.forward_native(x)
     # The SiLU and GELU implementations are equivalent to the native PyTorch
     # implementations, so we can do exact comparison.
     assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
@@ -72,7 +72,7 @@ def test_activation(
     x = torch.randn(num_tokens, d, dtype=dtype)
     layer = activation()
     out = layer(x)
-    ref_out = layer._forward(x)
+    ref_out = layer.forward_native(x)
     assert torch.allclose(out,
                           ref_out,
                           atol=get_default_atol(out),
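
The renamed hook makes the testing pattern explicit: run the module normally to hit the dispatched backend kernel, then call forward_native() for the pure-PyTorch reference. A minimal sketch of that pattern (shapes, dtype, and the helper name check_silu_and_mul are illustrative; assumes a CUDA build of vLLM):

import torch

from vllm.model_executor.layers.activation import SiluAndMul


def check_silu_and_mul(num_tokens: int = 7, d: int = 512) -> None:
    # The input has 2 * d features; the op splits it in half internally.
    x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")
    layer = SiluAndMul()

    out = layer(x)                     # dispatched kernel (forward_cuda here)
    ref_out = layer.forward_native(x)  # pure-PyTorch reference

    # Per the test comment above, the kernel mirrors the native math,
    # so the comparison can be exact.
    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)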

tests/kernels/test_layernorm.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def test_rms_norm(
 
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_out = layer._forward(x, residual)
+    ref_out = layer.forward_native(x, residual)
     out = layer(x, residual)
     # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
     # numerical errors than other operators because they involve reductions.

tests/kernels/test_pos_encoding.py

Lines changed: 4 additions & 3 deletions
@@ -64,7 +64,7 @@ def test_rotary_embedding(
 
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key)
+    ref_query, ref_key = rope.forward_native(positions, query, key)
     out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     assert torch.allclose(out_query,
@@ -121,7 +121,7 @@ def test_batched_rotary_embedding(
 
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key)
+    ref_query, ref_key = rope.forward_native(positions, query, key)
     out_query, out_key = rope.forward(positions,
                                       query,
                                       key,
@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(
 
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key, query_offsets)
+    ref_query, ref_key = rope.forward_native(positions, query, key,
+                                             query_offsets)
     out_query, out_key = rope.forward(positions, query, key,
                                       query_offsets.flatten())
     # Compare the results.

vllm/model_executor/custom_op.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import torch.nn as nn
+
+from vllm.utils import is_cpu, is_hip
+
+
+class CustomOp(nn.Module):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        """PyTorch-native implementation of the forward method.
+
+        This method is optional. If implemented, it can be used with compilers
+        such as torch.compile or PyTorch XLA. Also, it can be used for testing
+        purposes.
+        """
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_hip(self, *args, **kwargs):
+        # By default, we assume that HIP ops are compatible with CUDA ops.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_xpu(self, *args, **kwargs):
+        # By default, we assume that XPU ops are compatible with CUDA ops.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        # By default, we assume that CPU ops are compatible with CUDA ops.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_tpu(self, *args, **kwargs):
+        # By default, we assume that TPU ops are compatible with the
+        # PyTorch-native implementation.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_native(*args, **kwargs)
+
+    def forward_gaudi(self, *args, **kwargs):
+        # By default, we assume that Gaudi ops are compatible with the
+        # PyTorch-native implementation.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_native(*args, **kwargs)
+
+    def dispatch_forward(self):
+        # NOTE(woosuk): Here we assume that vLLM was built for only one
+        # specific backend. Currently, we do not support dynamic dispatching.
+        if is_hip():
+            return self.forward_hip
+        elif is_cpu():
+            return self.forward_cpu
+        else:
+            return self.forward_cuda
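
To illustrate how the dispatch works, here is a hedged sketch of a subclass. ReLUSquared is a made-up example op, and a real forward_cuda would call a fused kernel from vllm._custom_ops rather than reuse the native math:

import torch
import torch.nn.functional as F

from vllm.model_executor.custom_op import CustomOp


class ReLUSquared(CustomOp):
    """Hypothetical example op: relu(x) ** 2."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference; usable with torch.compile and for testing.
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would launch a fused kernel here; this sketch
        # simply reuses the native math.
        return self.forward_native(x)


layer = ReLUSquared()
x = torch.randn(4, 8)
# __init__ resolved self._forward_method via dispatch_forward(), so this call
# lands on forward_hip, forward_cpu, or forward_cuda depending on the build.
out = layer(x)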

vllm/model_executor/layers/activation.py

Lines changed: 21 additions & 13 deletions
@@ -6,14 +6,14 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from vllm import _custom_ops as ops
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
 
-class SiluAndMul(nn.Module):
+class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.
 
     The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -23,20 +23,22 @@ class SiluAndMul(nn.Module):
     return: (num_tokens, d) or (batch_size, seq_len, d)
     """
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         ops.silu_and_mul(out, x)
         return out
 
 
-class GeluAndMul(nn.Module):
+class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
 
     The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -52,12 +54,14 @@ def __init__(self, approximate: str = "none"):
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -71,28 +75,32 @@ def extra_repr(self) -> str:
         return f'approximate={repr(self.approximate)}'
 
 
-class NewGELU(nn.Module):
+class NewGELU(CustomOp):
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
         return 0.5 * x * (1.0 + torch.tanh(c *
                                            (x + 0.044715 * torch.pow(x, 3.0))))
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         out = torch.empty_like(x)
         ops.gelu_new(out, x)
         return out
 
 
-class FastGELU(nn.Module):
+class FastGELU(CustomOp):
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         out = torch.empty_like(x)
         ops.gelu_fast(out, x)
         return out
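
For reference, the SwiGLU math that SiluAndMul.forward_native implements can be reproduced standalone, with no vLLM kernels involved (function name and shapes here are illustrative):

import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x has shape (..., 2 * d): gate the first half with SiLU,
    # then multiply elementwise by the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(2, 8)
assert silu_and_mul_reference(x).shape == (2, 4)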

vllm/model_executor/layers/layernorm.py

Lines changed: 6 additions & 4 deletions
@@ -4,10 +4,10 @@
 import torch
 import torch.nn as nn
 
-from vllm import _custom_ops as ops
+from vllm.model_executor.custom_op import CustomOp
 
 
-class RMSNorm(nn.Module):
+class RMSNorm(CustomOp):
     """Root mean square normalization.
 
     Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
@@ -23,7 +23,7 @@ def __init__(
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def _forward(
+    def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
@@ -43,11 +43,13 @@ def _forward(
         else:
             return x, residual
 
-    def forward(
+    def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm import _custom_ops as ops
+
         if residual is not None:
             ops.fused_add_rms_norm(
                 x,
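
The RMSNorm docstring gives the formula w * x / sqrt(E[x^2] + eps). A standalone sketch of that native computation, including the optional residual add that the fused_add_rms_norm kernel folds into one launch (dtype upcasting and other details of vLLM's forward_native are simplified here):

from typing import Optional, Tuple, Union

import torch


def rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if residual is not None:
        # The CUDA path fuses this add into the normalization kernel.
        x = x + residual
        residual = x
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    return out if residual is None else (out, residual)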

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 6 additions & 4 deletions
@@ -27,7 +27,7 @@
 import torch
 import torch.nn as nn
 
-from vllm import _custom_ops as ops
+from vllm.model_executor.custom_op import CustomOp
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
-class RotaryEmbedding(nn.Module):
+class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 
     def __init__(
@@ -93,7 +93,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def _forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -138,13 +138,15 @@ def _forward(
         key = key.flatten(-2)
         return query, key
 
-    def forward(
+    def forward_cuda(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        from vllm import _custom_ops as ops
+
         self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                    dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
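
As the test comments note, ops.rotary_embedding() updates query and key in place, which is why the native reference must be computed before the kernel runs. A toy illustration of that ordering constraint with a stand-in in-place op (inplace_scale is hypothetical, not a vLLM API):

import torch


def inplace_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Stand-in for an in-place custom kernel such as ops.rotary_embedding().
    x.mul_(factor)
    return x


x = torch.randn(4)
ref = x * 2.0                  # reference first, from the untouched input
out = inplace_scale(x, 2.0)    # x itself is now overwritten
assert torch.allclose(out, ref)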
