Merged
5 changes: 5 additions & 0 deletions docs/source/models/extensions/fastsafetensor.md
@@ -0,0 +1,5 @@
Loading model weights with fastsafetensors
===========================================

Using the fastsafetensors library enables loading model weights to GPU memory by leveraging GPU Direct Storage. See https://github.com/foundation-model-stack/fastsafetensors for more details.
To enable this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``.
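A minimal usage sketch (my own illustration, not text from the PR): the format can also be selected programmatically, mirroring the test added later in this diff.

```python
# Hedged sketch: select the new load format explicitly via LoadFormat,
# as tests/fastsafetensors_loader/test_fastsafetensors_loader.py does.
from vllm import LLM, SamplingParams
from vllm.config import LoadFormat

llm = LLM(model="openai-community/gpt2",
          load_format=LoadFormat.FASTSAFETENSORS)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, top_p=0.95, seed=0))
print(outputs[0].outputs[0].text)
```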
1 change: 1 addition & 0 deletions docs/source/models/extensions/index.md
@@ -5,4 +5,5 @@

runai_model_streamer
tensorizer
fastsafetensor
:::
1 change: 1 addition & 0 deletions requirements/test.in
@@ -41,3 +41,4 @@ tritonclient==2.51.0
numpy < 2.0.0
runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
13 changes: 12 additions & 1 deletion requirements/test.txt
@@ -67,6 +67,7 @@ click==8.1.7
    #   jiwer
    #   nltk
    #   ray
    #   typer
colorama==0.4.6
    # via
    #   awscli
@@ -122,6 +123,8 @@ fastparquet==2024.11.0
    # via genai-perf
fastrlock==0.8.2
    # via cupy-cuda12x
fastsafetensors==0.1.10
    # via -r requirements/test.in
filelock==3.16.1
    # via
    #   datasets
@@ -505,7 +508,9 @@ requests==2.32.3
responses==0.25.3
    # via genai-perf
rich==13.9.4
    # via genai-perf
    # via
    #   genai-perf
    #   typer
rouge-score==0.1.2
    # via lm-eval
rpds-py==0.20.1
@@ -550,6 +555,8 @@ setuptools==75.8.0
    # via
    #   pytablewriter
    #   torch
shellingham==1.5.4
    # via typer
six==1.16.0
    # via
    #   python-dateutil
@@ -600,6 +607,7 @@ torch==2.6.0
    #   accelerate
    #   bitsandbytes
    #   encodec
    #   fastsafetensors
    #   lm-eval
    #   peft
    #   runai-model-streamer
@@ -654,6 +662,8 @@ typepy==1.3.2
    #   dataproperty
    #   pytablewriter
    #   tabledata
typer==0.15.2
    # via fastsafetensors
typing-extensions==4.12.2
    # via
    #   huggingface-hub
@@ -663,6 +673,7 @@ typing-extensions==4.12.2
    #   pydantic
    #   pydantic-core
    #   torch
    #   typer
tzdata==2024.2
    # via pandas
urllib3==2.2.3
1 change: 1 addition & 0 deletions setup.py
@@ -688,6 +688,7 @@ def _read_requirements(filename: str) -> list[str]:
    install_requires=get_requirements(),
    extras_require={
        "tensorizer": ["tensorizer>=2.9.0"],
        "fastsafetensors": ["fastsafetensors>=0.1.10"],
        "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
        "audio": ["librosa", "soundfile"],  # Required for audio processing
        "video": ["decord"]  # Required for video processing
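With this extra defined, the dependency can be pulled in at install time, e.g. `pip install "vllm[fastsafetensors]"`; any route that installs `fastsafetensors>=0.1.10` alongside vLLM should work equally well.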
Empty file.
22 changes: 22 additions & 0 deletions tests/fastsafetensors_loader/test_fastsafetensors_loader.py
@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import SamplingParams
from vllm.config import LoadFormat

test_model = "openai-community/gpt2"

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)


def test_model_loader_download_files(vllm_runner):
    with vllm_runner(test_model,
                     load_format=LoadFormat.FASTSAFETENSORS) as llm:
        deserialized_outputs = llm.generate(prompts, sampling_params)
        assert deserialized_outputs
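A note on running this test: it relies on the suite's shared `vllm_runner` fixture (provided by the test suite's conftest), so it should be invoked through pytest from the repository root rather than executed directly.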
46 changes: 46 additions & 0 deletions tests/fastsafetensors_loader/test_weight_utils.py
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0

import glob
import tempfile

import huggingface_hub.constants
import torch

from vllm.model_executor.model_loader.weight_utils import (
    download_weights_from_hf, fastsafetensors_weights_iterator,
    safetensors_weights_iterator)


def test_fastsafetensors_model_loader():
    with tempfile.TemporaryDirectory() as tmpdir:
        huggingface_hub.constants.HF_HUB_OFFLINE = False
        download_weights_from_hf("openai-community/gpt2",
                                 allow_patterns=["*.safetensors"],
                                 cache_dir=tmpdir)
        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
        assert len(safetensors) > 0

        fastsafetensors_tensors = {}
        hf_safetensors_tensors = {}

        for name, tensor in fastsafetensors_weights_iterator(
                safetensors, True):
            fastsafetensors_tensors[name] = tensor

        for name, tensor in safetensors_weights_iterator(safetensors, True):
            hf_safetensors_tensors[name] = tensor

        assert len(fastsafetensors_tensors) == len(hf_safetensors_tensors)

        for name, fastsafetensors_tensor in fastsafetensors_tensors.items():
            fastsafetensors_tensor = fastsafetensors_tensor.to('cpu')
            assert fastsafetensors_tensor.dtype == hf_safetensors_tensors[
                name].dtype
            assert fastsafetensors_tensor.shape == hf_safetensors_tensors[
                name].shape
            assert torch.all(
                fastsafetensors_tensor.eq(hf_safetensors_tensors[name]))


if __name__ == "__main__":
    test_fastsafetensors_model_loader()
1 change: 1 addition & 0 deletions vllm/config.py
@@ -1264,6 +1264,7 @@ class LoadFormat(str, enum.Enum):
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
    RUNAI_STREAMER = "runai_streamer"
    FASTSAFETENSORS = "fastsafetensors"


@dataclass
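Since `LoadFormat` is a `str` enum, the new member should also be accepted wherever a plain string load format is, e.g. `load_format="fastsafetensors"` in engine arguments or `--load-format fastsafetensors` on the CLI (assuming the existing flag wiring, which this diff leaves untouched).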
24 changes: 16 additions & 8 deletions vllm/model_executor/model_loader/loader.py
@@ -49,9 +49,10 @@
    set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
    get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
    fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
    filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
    get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
    np_cache_weights_iterator, pt_weights_iterator,
    runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -275,7 +276,8 @@ def _prepare_weights(
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
        elif (load_format == LoadFormat.SAFETENSORS
              or load_format == LoadFormat.FASTSAFETENSORS):
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
@@ -357,10 +359,16 @@ def _get_weights_iterator(
                self.load_config.use_tqdm_on_load,
            )
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
            if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
                weights_iterator = fastsafetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                )
            else:
                weights_iterator = safetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                )
        else:
            weights_iterator = pt_weights_iterator(
                hf_weights_files,
47 changes: 47 additions & 0 deletions vllm/model_executor/model_loader/weight_utils.py
@@ -38,6 +38,14 @@
    SafetensorsStreamer = runai_model_streamer.placeholder_attr(
        "SafetensorsStreamer")

try:
    from fastsafetensors import SafeTensorsFileLoader, SingleGroup
except ImportError:
    fastsafetensors = PlaceholderModule("fastsafetensors")
    SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
        "SafeTensorsFileLoader")
    SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")

logger = init_logger(__name__)

# use system-level temp directory for file locks, so that multiple users
@@ -452,6 +460,45 @@ def runai_safetensors_weights_iterator(
            yield from streamer.get_tensors()


def fastsafetensors_weights_iterator(
    hf_weights_files: List[str],
    use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files
    using the fastsafetensors library."""
    if torch.distributed.is_initialized():
        pg = torch.distributed.group.WORLD
    else:
        pg = SingleGroup()

    device = torch.device(f'cuda:{pg.rank()}')
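    # Chunk the file list into groups of pg.size() files; within each group,
    # rank i is assigned the i-th file, so every rank loads one file per pass.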
    weight_files_sub_lists = [
        hf_weights_files[i:i + pg.size()]
        for i in range(0, len(hf_weights_files), pg.size())
    ]

    for f_list in tqdm(
            weight_files_sub_lists,
            desc="Loading safetensors using Fastsafetensor loader",
            disable=not enable_tqdm(use_tqdm_on_load),
            bar_format=_BAR_FORMAT,
    ):
        loader = SafeTensorsFileLoader(pg, device)
        rank_file_map = {i: [f] for i, f in enumerate(f_list)}
        loader.add_filenames(rank_file_map)
        try:
            fb = loader.copy_files_to_device()
            try:
                keys = list(fb.key_to_rank_lidx.keys())
                for k in keys:
                    t = fb.get_tensor(k)
                    yield k, t
            finally:
                fb.close()
        finally:
            loader.close()


def pt_weights_iterator(
    hf_weights_files: List[str],
    use_tqdm_on_load: bool,
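To illustrate the file-to-rank distribution performed by `fastsafetensors_weights_iterator` above, here is a minimal standalone sketch; `world_size` stands in for `pg.size()`, and the file names are made up:

```python
# Round-robin chunking as in fastsafetensors_weights_iterator: with 5 files
# and a 2-rank group, each chunk maps rank -> exactly one file.
files = [f"model-{i}.safetensors" for i in range(5)]
world_size = 2  # stands in for pg.size()

sub_lists = [
    files[i:i + world_size] for i in range(0, len(files), world_size)
]
for f_list in sub_lists:
    rank_file_map = {rank: [f] for rank, f in enumerate(f_list)}
    print(rank_file_map)
# {0: ['model-0.safetensors'], 1: ['model-1.safetensors']}
# {0: ['model-2.safetensors'], 1: ['model-3.safetensors']}
# {0: ['model-4.safetensors']}
```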