1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -397,6 +397,7 @@ steps:
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_fusion_attn.py
- pytest -v -s compile/test_functionalization.py
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
298 changes: 228 additions & 70 deletions tests/compile/test_functionalization.py
@@ -5,54 +5,237 @@
import torch

import vllm.envs as envs
-from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

from .backend import TestBackend
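# TestBackend compiles the module with the given passes and records the FX
# graph after the passes run (graph_post_pass), which the assertions below
# inspect.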

-OPS_IN_MODEL = [
-    torch.ops._C.rotary_embedding.default,
-    torch.ops._C.fused_add_rms_norm.default,
-]
TEST_FP8 = current_platform.supports_fp8()
FP8_DTYPE = current_platform.fp8_dtype()
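# The fp8 branches below only run on platforms with fp8 support; elsewhere
# the models fall back to their unquantized ops.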


class TestSiluMul(torch.nn.Module):

def __init__(self, hidden_size: int = 128):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)

if TEST_FP8:
self.w = torch.rand(hidden_size,
hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)

def forward(self, x):
y = self.silu_and_mul(x)
if TEST_FP8:
            x2 = self.fp8_linear.apply(y,
                                       self.w,
                                       self.wscale,
                                       input_scale=self.scale)
return x2
else:
return y

def example_inputs(self, num_tokens=32, hidden_size=128):
dtype = torch.float16 if TEST_FP8 else torch.float32
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )

def ops_in_model(self, do_fusion):
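        # With fusion enabled, ActivationQuantFusionPass is expected to
        # replace silu_and_mul + fp8 quantization with the fused
        # silu_and_mul_quant op.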
if TEST_FP8 and do_fusion:
return [torch.ops._C.silu_and_mul_quant.default]
else:
return [torch.ops._C.silu_and_mul.default]

def ops_not_in_model(self):
return []


class TestFusedAddRMSNorm(torch.nn.Module):

def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size

dtype = torch.float16 if TEST_FP8 else torch.float32

self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size), dtype=dtype))
self.norm = RMSNorm(intermediate_size, 1e-05)
self.norm.weight = torch.nn.Parameter(
torch.ones(intermediate_size, dtype=dtype))

torch.nn.init.normal_(self.gate_proj, std=0.02)

if TEST_FP8:
self.fp8_linear = Fp8LinearOp(act_quant_static=True)

self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size,
intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)

def forward(self, hidden_states, residual):
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)

# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)

# layer normalization
norm_output, residual_output = self.norm(mm, residual)

if TEST_FP8:
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)

return fp8_linear_result, residual_output

else:
return norm_output, residual_output

def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
dtype = torch.float16 if TEST_FP8 else torch.float32
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
return (hidden_states, residual)

-RMS_OP = torch.ops._C.rms_norm.default
def ops_in_model(self, do_fusion):
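        # With fusion enabled, RMSNormQuantFusionPass is expected to replace
        # fused_add_rms_norm + static fp8 quantization with
        # fused_add_rms_norm_static_fp8_quant.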
if TEST_FP8 and do_fusion:
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
else:
return [torch.ops._C.fused_add_rms_norm.default]

-RMS_QUANT_OPS = {
-    "static_fp8": [
-        torch.ops._C.rms_norm_static_fp8_quant.default,
-        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
-    ],
-}
def ops_not_in_model(self):
return []

-SILU_MUL_OP = torch.ops._C.silu_and_mul.default
-
-SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
-
-prompts = [
-    "Hello, my name is",
-    "The president of the United States is",
-    "The capital of France is",
-    "The future of AI is",
-]


class TestRotaryEmbedding(torch.nn.Module):

def __init__(self,
head_dim=64,
rotary_dim=None,
max_position=2048,
base=10000):
super().__init__()
self.head_dim = head_dim
self.rotary_dim = rotary_dim or head_dim

self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position,
base=base,
)

def forward(self, positions, q, k):
q_rotated, k_rotated = self.rotary_emb(positions, q, k)
return q_rotated, k_rotated

def example_inputs(self, num_tokens=32, head_dim=64):
dtype = torch.float16
positions = torch.arange(num_tokens, dtype=torch.long)
q = torch.randn(num_tokens, head_dim, dtype=dtype)
k = torch.randn(num_tokens, head_dim, dtype=dtype)
return (positions, q, k)

def ops_in_model(self, do_fusion):
return [torch.ops._C.rotary_embedding.default]

def ops_not_in_model(self):
return []


class TestRotaryEmbeddingSliceScatter(torch.nn.Module):

def __init__(self,
head_dim=64,
num_heads=4,
max_position=2048,
base=10000):
super().__init__()
self.head_dim = head_dim
self.num_heads = num_heads
self.hidden_size = head_dim * num_heads

self.qkv_proj = torch.nn.Linear(self.hidden_size,
self.hidden_size * 3,
bias=False,
dtype=torch.float16)

self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=base,
)

def forward(self, positions, hidden_states):
# Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
# -> slice_scatter -> split_with_sizes
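        # Functionalizing the in-place rotary_embedding op on the q/k views
        # introduces slice_scatter ops; FixFunctionalizationPass should remove
        # them, which ops_not_in_model() asserts below.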

qkv = self.qkv_proj(hidden_states)
split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
q, k, v = torch.split(qkv, split_sizes, dim=-1)

q_rotated, k_rotated = self.rotary_emb(positions, q, k)

qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
return qkv_updated

def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
dtype = torch.float16
hidden_size = head_dim * num_heads
positions = torch.arange(num_tokens, dtype=torch.long)
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
return (positions, hidden_states)

def ops_in_model(self, do_fusion):
return [torch.ops._C.rotary_embedding.default]

def ops_not_in_model(self):
return [torch.ops.aten.slice_scatter.default]


MODELS = [
TestSiluMul,
TestFusedAddRMSNorm,
TestRotaryEmbedding,
TestRotaryEmbeddingSliceScatter,
]
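# Each model class exposes example_inputs() plus ops_in_model() /
# ops_not_in_model(), which drive the compilation and graph checks below.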


-@pytest.mark.parametrize(
-    "model, quant_key",
-    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
-     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
-      kFp8DynamicTokenSym)])
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
-def test_fix_functionalization(model: str, quant_key: QuantKey,
-                               do_fusion: bool):
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
torch.set_default_device("cuda")

vllm_config = VllmConfig()
@@ -63,56 +246,31 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

-    passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
-              ] if do_fusion else [noop_pass, cleanup_pass]
passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion else [noop_pass, cleanup_pass])
func_pass = FixFunctionalizationPass(vllm_config)

backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
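    # Compile the same model twice, with and without FixFunctionalizationPass,
    # so the resulting graphs can be compared.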

-    # instantiate a full engine and manually compile the model 2x
-    # (with and without FixFunctionalizationPass)
-    llm = LLM(model=model, enforce_eager=True)
-    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
-    orig_model = model_runner.model
-    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
-    # Can only do that by using the decorator but then we'd have to
-    # instantiate 2 LLM instances.
-
-    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_func)
-    gen_func = llm.generate(prompts, sampling_params)
-
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_no_func)
-
-    gen_no_func = llm.generate(prompts, sampling_params)
-
-    for output_func, output_no_func in zip(gen_func, gen_no_func):
-        assert output_func.outputs[0].text == output_no_func.outputs[0].text
-
-    # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
-    # and replaced by fused quantized ops in RMS_QUANT_OPS.
-    rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
-               ] if do_fusion else [RMS_OP]
-    silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
-        quant_key == kFp8StaticTensorSym else [SILU_MUL_OP]
-
-    ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
-
-    for op in ops:
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())

# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
-                                  op) is None  # noqa: E501
assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
is None) # noqa: E501

# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
-        for op in ops:
for op in model.ops_in_model(do_fusion):
if is_func(node, op):
found[op] = True
for op in model.ops_not_in_model():
if is_func(node, op):
found[op] = True
-    assert all(found[op] for op in ops)
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())