
[RFC][Exploratory]: vLLM Neuron Backend with V1 Architecture #11152

@liangfu

Description

Motivation.

To leverage the vLLM V1 architecture change, we propose a new integration approach for the Neuron backend that integrates seamlessly with vLLM, while maintaining high performance and treating prefix caching as a first-class feature.

Background

(ref: #8779)
vLLM is on a path toward 1/ full support for torch.compile, 2/ turning on chunked prefill, prefix caching, and speculative decoding by default, and 3/ supporting more than 60 model variants. Meanwhile, the current Neuron backend is integrated via the transformers-neuronx library, which has limited support for the combination of these features.

To support a wide range of model variants, vLLM maintains a modular design in the vllm.model_executor.layers module. This makes it easy for new model developers to contribute support for new model variants to vLLM. For instance, the Mistral team released the pixtral-large model weights and brought pixtral-large model support with Pixtral (vllm-project#8377).

Proposed Change.

Embracing torch.compile support

As part of the Neuron SDK 2.21 release, we are able to support torch.compile with the openxla backend. For instance, we can implement copy_blocks as follows:

from typing import List, Tuple

import torch


class NeuronAttentionBackend:

    @staticmethod
    @torch.compile(backend="openxla")
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            # Mark the caches as buffer donors so XLA can update them in place.
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]

# Alternative: compile at the call site instead of via the decorator
# (the pattern used by the existing Pallas/TPU backend).
if use_torch_compile:
    compiled_code = torch.compile(PallasAttentionBackend.copy_blocks,
                                  backend="openxla")
    compiled_code(kv_caches, src_to_dists)
else:
    PallasAttentionBackend.copy_blocks(kv_caches, src_to_dists)

Build neuron attention backend with NKI

We may build an NKI-based flash-attention kernel with paged KV cache as part of the vllm.attention.ops module, similar to the Triton-based flash-attention kernel in vLLM (ref: triton_flash_attention.py); a rough interface sketch follows.
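
As a rough sketch of what the Python wrapper around such a kernel could look like (the flash_attn_varlen_nki name, the tensor layouts, and the signature below are assumptions for illustration, not an existing kernel API):

from typing import Optional

import torch


def flash_attn_varlen_nki(  # hypothetical name; the real NKI kernel may differ
    query: torch.Tensor,         # [num_tokens, num_heads, head_size]
    key_cache: torch.Tensor,     # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,   # same layout as key_cache
    block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_seq]
    seq_lens: torch.Tensor,      # [num_seqs]
    scale: float,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Paged flash-attention entry point, to be backed by an NKI kernel.

    Mirrors the role of triton_flash_attention.py: the Python wrapper only
    prepares metadata and dispatches to the device kernel.
    """
    if out is None:
        out = torch.empty_like(query)
    # The NKI kernel launch would go here; a reference (non-NKI)
    # implementation can serve as a fallback for correctness testing.
    return out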

Introduce forward_neuron into vllm.model_executor.layers module

As on many other hardware backends in vLLM, the portable native kernel may not be the most performant option on a specific device. vLLM maintains the default behavior in each CustomOp's forward_native method, while hardware-specific optimizations can be enabled via a corresponding forward_xxx method. We propose adding a forward_neuron path for the Neuron backend; a simplified dispatch sketch follows.
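
For illustration only, a minimal sketch of this dispatch pattern (this is not vLLM's actual CustomOp implementation; the platform check is a placeholder):

import torch


class CustomOpSketch(torch.nn.Module):
    """Simplified illustration of per-backend dispatch, not vLLM's CustomOp."""

    def forward(self, *args, **kwargs):
        # Pick the hardware-specific implementation when running on Neuron,
        # otherwise fall back to the portable reference implementation.
        if self._is_neuron():
            return self.forward_neuron(*args, **kwargs)
        return self.forward_native(*args, **kwargs)

    @staticmethod
    def _is_neuron() -> bool:
        # Placeholder platform check; vLLM has its own platform detection.
        return False

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_neuron(self, *args, **kwargs):
        # Default: reuse the native path until a Neuron kernel is provided.
        return self.forward_native(*args, **kwargs)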

We can reuse the performant components in the neuronx_distributed package to further improve performance. For example:

from typing import Optional, Tuple, Union

import torch

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    """

    def forward_neuron(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Fall back to the portable implementation when the fused kernel
        # from neuronx_distributed is unavailable.
        try:
            from neuronx_distributed.ops import NeuronFusedRMSNorm
        except ImportError:
            NeuronFusedRMSNorm = None
        if NeuronFusedRMSNorm is None:
            return self.forward_native(x, residual)
        if residual is not None:
            orig_shape = x.shape
            residual += x.view(residual.shape)
            x = NeuronFusedRMSNorm.apply(residual, self.weight,
                                         self.variance_epsilon)
            return x.view(orig_shape), residual
        return NeuronFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)

Development Progress

- NKI-based flash-attention kernel
- End-to-end full model support
- Test infra improvement
