5 changes: 3 additions & 2 deletions examples/offline_inference.py
@@ -1,4 +1,5 @@
from vllm import LLM, SamplingParams
import torch

# Sample prompts.
prompts = [
@@ -8,10 +9,10 @@
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams(temperature=0.0, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="meta-llama/Meta-Llama-3-8b", tensor_parallel_size=2, enforce_eager=True, dtype=torch.float16)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
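With temperature set to 0.0, decoding is greedy and therefore deterministic, which makes it easier to check that the fused kernels do not change the generated text. A minimal sketch of such a check (hypothetical, not part of the PR; it assumes the same two-GPU setup and model access as the example above):

import torch
from vllm import LLM, SamplingParams

prompts = ["The future of AI is"]
params = SamplingParams(temperature=0.0, top_p=0.95)
llm = LLM(model="meta-llama/Meta-Llama-3-8b",
          tensor_parallel_size=2,
          enforce_eager=True,
          dtype=torch.float16)
# Greedy decoding should be reproducible run-to-run.
first = [out.outputs[0].text for out in llm.generate(prompts, params)]
second = [out.outputs[0].text for out in llm.generate(prompts, params)]
assert first == second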
17 changes: 17 additions & 0 deletions flux_env.sh
@@ -0,0 +1,17 @@
# Point to the directory containing the flux .so files:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/nm-vllm/flux_experiment/lib

export NVSHMEM_BOOTSTRAP_MPI_PLUGIN=nvshmem_bootstrap_torch.so

# Env variables for symmetric heap allocation.
# These are needed to support CUDA_VISIBLE_DEVICES.
# 8 GiB is big enough for Llama 3 8B, but the value should be sized to the model being served.
export NVSHMEM_SYMMETRIC_SIZE=$((8*1024**3))
export NVSHMEM_DISABLE_CUDA_VMM=1 # previously set in C++; moved into this script

# Not sure if these are needed
export CUDA_DEVICE_MAX_CONNECTIONS=1
export BYTED_TORCH_BYTECCL=O0
export NCCL_IB_TIMEOUT=${NCCL_IB_TIMEOUT:=23}
export NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX:=3}
export NVSHMEM_IB_GID_INDEX=3
5 changes: 5 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -29,6 +29,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import flux
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
@@ -200,6 +201,10 @@ def __init__(
self.use_custom_allreduce = use_custom_allreduce
self.use_tpu_communicator = use_tpu_communicator

# Initialize flux shared memory (pynvshmem) when the group spans more than one rank.
if torch.distributed.get_world_size(self.device_group) > 1:
flux.init_flux_shm(self.device_group)

# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
152 changes: 139 additions & 13 deletions vllm/model_executor/layers/linear.py
@@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple

import flux
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -10,6 +11,7 @@
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
@@ -135,6 +137,104 @@ def apply(self,
return F.linear(x, layer.weight, bias)


class GemmRS(LinearMethodBase):
"""Fused GEMM + ReduceScatter linear method without quantization."""

def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

self.gemm_rs_op = flux.GemmRS(
get_tp_group().device_group,
1, # One node
8192, # Max M. TODO: Pass in correctly.
output_size, # N
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight=False,
# Note: bfloat16 requires fuse_reduction=False.
fuse_reduction=False,
)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert bias is None

output = self.gemm_rs_op.forward(x, layer.weight)
output = output.squeeze(0)

return output


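For reference, the unfused path that GemmRS replaces is roughly a local GEMM followed by a reduce-scatter of the partial sums over the tensor-parallel group. A hedged sketch (the tp_group handle is illustrative; the flux kernel fuses and overlaps these two steps rather than running them back to back):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def gemm_then_reduce_scatter(x: torch.Tensor, weight: torch.Tensor,
                             tp_group) -> torch.Tensor:
    # Local GEMM against this rank's weight shard; rows are partial sums.
    partial = F.linear(x, weight)
    world_size = dist.get_world_size(tp_group)
    # Reduce-scatter along the token dimension: each rank keeps a fully
    # reduced 1/world_size slice of the rows.
    out = torch.empty((partial.shape[0] // world_size,) + partial.shape[1:],
                      dtype=partial.dtype, device=partial.device)
    dist.reduce_scatter_tensor(out, partial, group=tp_group)
    return out
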
class AGCook(LinearMethodBase):
"""Fused AllGather + GEMM linear method without quantization."""

def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

self.ag_gemm_op = flux.AGKernel(
get_tp_group().device_group,
1, # One node
8192, # Max M. TODO: Pass in correctly.
weight.shape[0], # N
weight.shape[1], # K
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
torch.float16,
torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight=False,
# Note: if local_copy=True, I hit the following runtime error:
# /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
# Check failed: 33554432((input.numel() * input.element_size()))
# == 139836453421056((this->chunk_size))
local_copy=False,
)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert bias is None

output = self.ag_gemm_op.forward(x, layer.weight)

return output


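Likewise, the unfused path that AGCook replaces is an all-gather of the sequence-sharded activations followed by a local GEMM. A hedged sketch with an illustrative tp_group handle (the flux kernel overlaps the communication with the matmul instead of serializing them):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def all_gather_then_gemm(x: torch.Tensor, weight: torch.Tensor,
                         tp_group) -> torch.Tensor:
    world_size = dist.get_world_size(tp_group)
    # Gather every rank's slice of the tokens so the full input is local.
    gathered = torch.empty((x.shape[0] * world_size,) + x.shape[1:],
                           dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(gathered, x, group=tp_group)
    # Local GEMM against this rank's weight shard.
    return F.linear(gathered, weight)
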
class LinearBase(torch.nn.Module):
"""Base linear layer.

@@ -155,6 +255,8 @@ def __init__(
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
fuse_gemm_rs: bool = False,
fuse_ag_gemm: bool = False,
):
super().__init__()

@@ -165,9 +267,15 @@ def __init__(
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()

if fuse_gemm_rs:
assert (quant_config is None)
self.quant_method: Optional[QuantizeMethodBase] = GemmRS()
elif fuse_ag_gemm:
assert (quant_config is None)
self.quant_method = AGCook()
elif quant_config is None:
self.quant_method = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
@@ -280,9 +388,15 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
prefix: str = "",
fuse_ag_gemm: bool = False):
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
fuse_ag_gemm=fuse_ag_gemm)

self.gather_output = gather_output

@@ -413,7 +527,8 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
fuse_ag_gemm: bool = False):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -424,7 +539,8 @@
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
prefix=prefix,
fuse_ag_gemm=fuse_ag_gemm)

def weight_loader(self,
param: Parameter,
@@ -654,7 +770,8 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
fuse_ag_gemm: bool = False):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
@@ -687,7 +804,8 @@ def __init__(self,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
prefix=prefix,
fuse_ag_gemm=fuse_ag_gemm)

def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
@@ -967,12 +1085,20 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
prefix: str = "",
fuse_gemm_rs: bool = False):
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
fuse_gemm_rs=fuse_gemm_rs)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if fuse_gemm_rs:
self.reduce_results = False

# Divide the weight matrix along the last dimension.
self.tp_rank = get_tensor_model_parallel_rank()
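Taken together, the new flags let callers opt individual layers into the fused paths. A hedged usage sketch (layer sizes are invented, the surrounding tensor-parallel state is assumed to be initialized, and bias is disabled because GemmRS and AGCook assert that bias is None):

from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)

# Column-parallel projection using the fused AllGather + GEMM path.
up_proj = ColumnParallelLinear(input_size=4096,
                               output_size=14336,
                               bias=False,
                               fuse_ag_gemm=True)

# Row-parallel projection using the fused GEMM + ReduceScatter path;
# reduce_results is forced off because the reduce-scatter replaces the
# usual all-reduce.
down_proj = RowParallelLinear(input_size=14336,
                              output_size=4096,
                              bias=False,
                              fuse_gemm_rs=True)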