Description
Motivation.
To leverage the vLLM V1 architecture change, we propose a new integration approach for the Neuron backend that integrates seamlessly with vLLM, while maintaining high performance and treating prefix caching as a first-class feature.
Background
(ref: #8779)
vLLM is on a path toward 1/ full support for torch.compile, 2/ enabling chunked prefill, prefix caching, and speculative decoding by default, and 3/ supporting more than 60 model variants. Meanwhile, the current Neuron backend is implemented via the transformers-neuronx library, which has limited support for the combination of these features.
To support a wide range of model variants, vLLM maintains a modular design around the vllm.model_executor.layers module. This makes it easy for model developers to contribute support for new model variants. For instance, the Mistral team released the pixtral-large model weights and contributed model support with Pixtral (vllm-project#8377).
Proposed Change.
Embracing torch.compile support
As part of the Neuron SDK 2.21 release, we are able to support torch.compile with the openxla backend. For instance, we can implement copy_blocks as follows:
from typing import List, Tuple

import torch


class NeuronAttentionBackend:

    @staticmethod
    @torch.compile(backend="openxla")
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            # Mark the caches as XLA buffer donors so the copies happen in place.
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]


# Alternatively, compile at the call site so compilation can be toggled at
# runtime (pattern shared with the TPU Pallas backend):
if use_torch_compile:
    compiled_code = torch.compile(PallasAttentionBackend.copy_blocks,
                                  backend="openxla")
    compiled_code(kv_caches, src_to_dists)
else:
    PallasAttentionBackend.copy_blocks(kv_caches, src_to_dists)
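Both routes target the same openxla backend: decorating copy_blocks bakes compilation in at definition time, while the call-site pattern lets the runner toggle compilation at runtime (for example, to fall back to eager execution while debugging).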
Build neuron attention backend with NKI
We may build an NKI-based flash-attention kernel with paged KV cache as part of the vllm.attention.ops module, similar to the Triton-based flash-attention in vLLM (ref: triton_flash_attention.py); see the wrapper sketch below.
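As a rough sketch of where such a kernel would plug in, the wrapper below shows how a paged flash-attention op could be exposed under vllm.attention.ops. The module name nki_flash_attn, the entry point flash_paged_attention, and its argument layout are illustrative assumptions, not the actual interface introduced in #11277.

import torch


def nki_flash_paged_attention(
    query: torch.Tensor,         # [num_tokens, num_heads, head_size]
    key_cache: torch.Tensor,     # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,   # [num_blocks, block_size, num_kv_heads, head_size]
    block_tables: torch.Tensor,  # [num_seqs, max_num_blocks_per_seq]
    context_lens: torch.Tensor,  # [num_seqs]
    scale: float,
) -> torch.Tensor:
    """Thin wrapper dispatching to an NKI flash-attention kernel that reads
    K/V from a paged cache, mirroring how triton_flash_attention.py wraps
    its Triton kernel."""
    # Hypothetical kernel entry point; the real kernel is introduced in #11277.
    from vllm.attention.ops.nki_flash_attn import flash_paged_attention

    out = torch.empty_like(query)
    flash_paged_attention(out, query, key_cache, value_cache, block_tables,
                          context_lens, scale)
    return out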
Introduce forward_neuron into vllm.model_executor.layers module
As with other hardware backends in vLLM, the portable native kernel may not be highly performant on a specific hardware target. vLLM maintains the default behavior in a forward_native function, while hardware-specific optimizations can be enabled through a forward_xxx function.
We can reuse the performant components in the neuronx_distributed package to further improve performance.
from typing import Optional, Tuple, Union

import torch

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    """

    def forward_neuron(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        from neuronx_distributed.ops import NeuronFusedRMSNorm

        # Fall back to the portable implementation if the fused op is not
        # available in the installed neuronx_distributed build.
        if NeuronFusedRMSNorm is None:
            return self.forward_native(x, residual)
        if residual is not None:
            orig_shape = x.shape
            residual += x.view(residual.shape)
            x = NeuronFusedRMSNorm.apply(residual, self.weight,
                                         self.variance_epsilon)
            return x.view(orig_shape), residual
        x = NeuronFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
        return x
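For context, here is a simplified illustration of how per-platform dispatch could select forward_neuron. This is not vLLM's actual CustomOp code (which lives in vllm/model_executor/custom_op.py); it only assumes the usual current_platform check from vllm.platforms.

import torch

from vllm.platforms import current_platform


class CustomOpSketch(torch.nn.Module):
    """Simplified dispatch illustration only; the real logic lives in
    vllm/model_executor/custom_op.py."""

    def __init__(self):
        super().__init__()
        # Resolve the platform-specific implementation once, at construction.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def dispatch_forward(self):
        # Prefer the Neuron-optimized path on Neuron devices and fall back
        # to the portable reference implementation elsewhere.
        if current_platform.is_neuron():
            return self.forward_neuron
        return self.forward_native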
Development Progress
NKI-based Flash-attention kernel
- [Neuron][Kernel] NKI-based flash-attention kernel with paged KV cache #11277
- [Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth #13245
- [Neuron][Kernel] NKI Flash PagedAttention with BlockSparse Execution Plan #13249
End-to-end full model support
- [Neuron] Add custom_ops for neuron backend #13246
- [Neuron] Add Neuron device communicator for vLLM v1 #14085
- [neuron] add reshape_and_cache #14391
- [Neuron][V1] Experimental support for neuron backend with V1 architecture #14648
Test infra improvement
- Test multiple scripts by running bash
- [Docker] bump up neuron sdk v2.21 #11593
- [CI] Fix neuron CI and run offline tests #11779