diff --git a/doc/BUILD.bazel b/doc/BUILD.bazel index 2dcca3b6ac1e..2455f640506a 100644 --- a/doc/BUILD.bazel +++ b/doc/BUILD.bazel @@ -333,6 +333,30 @@ py_test_run_all_subdirectory( ], ) +# -------------------------------------------------------------------- +# Test all doc/source/data/doc_code/working-with-llms code included in rst/md files. +# -------------------------------------------------------------------- + +filegroup( + name = "data_llm_examples", + srcs = glob(["source/data/doc_code/working-with-llms/**/*.py"]), + visibility = ["//doc:__subpackages__"], +) + +# GPU Tests +py_test_run_all_subdirectory( + size = "large", + include = ["source/data/doc_code/working-with-llms/**/*.py"], + exclude = [], + extra_srcs = [], + tags = [ + "exclusive", + "gpu", + "team:data", + "team:llm" + ], +) + # -------------------------------------------------------------------- # Test all doc/source/tune/doc_code code included in rst/md files. # -------------------------------------------------------------------- @@ -527,8 +551,6 @@ doctest_each( # These tests run on GPU (see below). "source/data/batch_inference.rst", "source/data/transforming-data.rst", - # These tests are currently failing. - "source/data/working-with-llms.rst", # These don't contain code snippets. "source/data/api/**/*.rst", ], diff --git a/doc/source/data/doc_code/working-with-llms/basic_llm_example.py b/doc/source/data/doc_code/working-with-llms/basic_llm_example.py new file mode 100644 index 000000000000..f1f54d60fdb2 --- /dev/null +++ b/doc/source/data/doc_code/working-with-llms/basic_llm_example.py @@ -0,0 +1,200 @@ +""" +This file serves as a documentation example and CI test for basic LLM batch inference. + +""" + +# Dependency setup +import subprocess +import sys + +subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "ray[llm]"]) +subprocess.check_call( + [sys.executable, "-m", "pip", "install", "--upgrade", "transformers"] +) +subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.26.4"]) + + +# __basic_llm_example_start__ +import ray +from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor + +# __basic_config_example_start__ +# Basic vLLM configuration +config = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs={ + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4096, # Reduce if CUDA OOM occurs + "max_model_len": 16384, + }, + concurrency=1, + batch_size=64, +) +# __basic_config_example_end__ + +processor = build_llm_processor( + config, + preprocess=lambda row: dict( + messages=[ + {"role": "system", "content": "You are a bot that responds with haikus."}, + {"role": "user", "content": row["item"]}, + ], + sampling_params=dict( + temperature=0.3, + max_tokens=250, + ), + ), + postprocess=lambda row: dict( + answer=row["generated_text"], + **row, # This will return all the original columns in the dataset. + ), +) + +ds = ray.data.from_items(["Start of the haiku is: Complete this for me..."]) + +if __name__ == "__main__": + try: + import torch + + if torch.cuda.is_available(): + ds = processor(ds) + ds.show(limit=1) + else: + print("Skipping basic LLM run (no GPU available)") + except Exception as e: + print(f"Skipping basic LLM run due to environment error: {e}") + +# __hf_token_config_example_start__ +# Configuration with Hugging Face token +config_with_token = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + runtime_env={"env_vars": {"HF_TOKEN": "your_huggingface_token"}}, + concurrency=1, + batch_size=64, +) +# __hf_token_config_example_end__ + +# __parallel_config_example_start__ +# Model parallelism configuration for larger models +# tensor_parallel_size=2: Split model across 2 GPUs for tensor parallelism +# pipeline_parallel_size=2: Use 2 pipeline stages (total 4 GPUs needed) +# Total GPUs required = tensor_parallel_size * pipeline_parallel_size = 4 +config = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs={ + "max_model_len": 16384, + "tensor_parallel_size": 2, + "pipeline_parallel_size": 2, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 2048, + }, + concurrency=1, + batch_size=32, + accelerator_type="L4", +) +# __parallel_config_example_end__ + +# __runai_config_example_start__ +# RunAI streamer configuration for optimized model loading +# Note: Install vLLM with runai dependencies: pip install -U "vllm[runai]>=0.10.1" +config = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs={ + "load_format": "runai_streamer", + "max_model_len": 16384, + }, + concurrency=1, + batch_size=64, +) +# __runai_config_example_end__ + +# __lora_config_example_start__ +# Multi-LoRA configuration +config = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs={ + "enable_lora": True, + "max_lora_rank": 32, + "max_loras": 1, + "max_model_len": 16384, + }, + concurrency=1, + batch_size=32, +) +# __lora_config_example_end__ + +# __s3_config_example_start__ +# S3 hosted model configuration +s3_config = vLLMEngineProcessorConfig( + model_source="s3://your-bucket/your-model-path/", + engine_kwargs={ + "load_format": "runai_streamer", + "max_model_len": 16384, + }, + concurrency=1, + batch_size=64, +) +# __s3_config_example_end__ + +# __gpu_memory_config_example_start__ +# GPU memory management configuration +# If you encounter CUDA out of memory errors, try these optimizations: +config_memory_optimized = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs={ + "max_model_len": 8192, + "max_num_batched_tokens": 2048, + "enable_chunked_prefill": True, + "gpu_memory_utilization": 0.85, + "block_size": 16, + }, + concurrency=1, + batch_size=16, +) + +# For very large models or limited GPU memory: +config_minimal_memory = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs={ + "max_model_len": 4096, + "max_num_batched_tokens": 1024, + "enable_chunked_prefill": True, + "gpu_memory_utilization": 0.75, + }, + concurrency=1, + batch_size=8, +) +# __gpu_memory_config_example_end__ + +# __embedding_config_example_start__ +# Embedding model configuration +embedding_config = vLLMEngineProcessorConfig( + model_source="sentence-transformers/all-MiniLM-L6-v2", + task_type="embed", + engine_kwargs=dict( + enable_prefix_caching=False, + enable_chunked_prefill=False, + max_model_len=256, + enforce_eager=True, + ), + batch_size=32, + concurrency=1, + apply_chat_template=False, + detokenize=False, +) + +# Example usage for embeddings +def create_embedding_processor(): + return build_llm_processor( + embedding_config, + preprocess=lambda row: dict(prompt=row["text"]), + postprocess=lambda row: { + "text": row["prompt"], + "embedding": row["embeddings"], + }, + ) + + +# __embedding_config_example_end__ + +# __basic_llm_example_end__ diff --git a/doc/source/data/doc_code/working-with-llms/embedding_example.py b/doc/source/data/doc_code/working-with-llms/embedding_example.py new file mode 100644 index 000000000000..b1bf2e46b10a --- /dev/null +++ b/doc/source/data/doc_code/working-with-llms/embedding_example.py @@ -0,0 +1,63 @@ +""" +Documentation example and test for embedding model batch inference. + +""" + +import subprocess +import sys + +subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "ray[llm]"]) +subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.26.4"]) + + +def run_embedding_example(): + # __embedding_example_start__ + import ray + from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor + + embedding_config = vLLMEngineProcessorConfig( + model_source="sentence-transformers/all-MiniLM-L6-v2", + task_type="embed", + engine_kwargs=dict( + enable_prefix_caching=False, + enable_chunked_prefill=False, + max_model_len=256, + enforce_eager=True, + ), + batch_size=32, + concurrency=1, + apply_chat_template=False, + detokenize=False, + ) + + embedding_processor = build_llm_processor( + embedding_config, + preprocess=lambda row: dict(prompt=row["text"]), + postprocess=lambda row: { + "text": row["prompt"], + "embedding": row["embeddings"], + }, + ) + + texts = [ + "Hello world", + "This is a test sentence", + "Embedding models convert text to vectors", + ] + ds = ray.data.from_items([{"text": text} for text in texts]) + + embedded_ds = embedding_processor(ds) + embedded_ds.show(limit=1) + # __embedding_example_end__ + + +if __name__ == "__main__": + try: + import torch + + if torch.cuda.is_available(): + run_embedding_example() + else: + print("Skipping embedding example (no GPU available)") + except Exception as e: + print(f"Skipping embedding example: {e}") diff --git a/doc/source/data/doc_code/working-with-llms/openai_api_example.py b/doc/source/data/doc_code/working-with-llms/openai_api_example.py new file mode 100644 index 000000000000..6c707984d79f --- /dev/null +++ b/doc/source/data/doc_code/working-with-llms/openai_api_example.py @@ -0,0 +1,99 @@ +""" +This file serves as a documentation example and CI test for OpenAI API batch inference. + +""" + +import os +from ray.data.llm import HttpRequestProcessorConfig, build_llm_processor + + +def run_openai_example(): + # __openai_example_start__ + import ray + + OPENAI_KEY = os.environ["OPENAI_API_KEY"] + ds = ray.data.from_items(["Hand me a haiku."]) + + config = HttpRequestProcessorConfig( + url="https://api.openai.com/v1/chat/completions", + headers={"Authorization": f"Bearer {OPENAI_KEY}"}, + qps=1, + ) + + processor = build_llm_processor( + config, + preprocess=lambda row: dict( + payload=dict( + model="gpt-4o-mini", + messages=[ + { + "role": "system", + "content": "You are a bot that responds with haikus.", + }, + {"role": "user", "content": row["item"]}, + ], + temperature=0.0, + max_tokens=150, + ), + ), + postprocess=lambda row: dict( + response=row["http_response"]["choices"][0]["message"]["content"] + ), + ) + + ds = processor(ds) + print(ds.take_all()) + # __openai_example_end__ + + +def run_openai_demo(): + """Run the OpenAI API configuration demo.""" + print("OpenAI API Configuration Demo") + print("=" * 30) + print("\nExample configuration:") + print("config = HttpRequestProcessorConfig(") + print(" url='https://api.openai.com/v1/chat/completions',") + print(" headers={'Authorization': f'Bearer {OPENAI_KEY}'},") + print(" qps=1,") + print(")") + print("\nThe processor handles:") + print("- Preprocessing: Convert text to OpenAI API format") + print("- HTTP requests: Send batched requests to OpenAI") + print("- Postprocessing: Extract response content") + + +def preprocess_for_openai(row): + """Preprocess function for OpenAI API requests.""" + return dict( + payload=dict( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": row["item"]}, + ], + temperature=0.0, + max_tokens=150, + ) + ) + + +def postprocess_openai_response(row): + """Postprocess function for OpenAI API responses.""" + return dict(response=row["http_response"]["choices"][0]["message"]["content"]) + + +if __name__ == "__main__": + # Run live call if API key is set; otherwise show demo with mock output + if "OPENAI_API_KEY" in os.environ: + run_openai_example() + else: + # Mock response without API key + print( + [ + { + "response": ( + "Autumn leaves whisper\nSoft code flows in quiet lines\nBugs fall one by one" + ) + } + ] + ) diff --git a/doc/source/data/doc_code/working-with-llms/vlm_example.py b/doc/source/data/doc_code/working-with-llms/vlm_example.py new file mode 100644 index 000000000000..cbfb3d2de4fa --- /dev/null +++ b/doc/source/data/doc_code/working-with-llms/vlm_example.py @@ -0,0 +1,215 @@ +""" +This file serves as a documentation example and CI test for VLM batch inference. + +Structure: +1. Infrastructure setup: Dataset compatibility patches, dependency handling +2. Docs example (between __vlm_example_start/end__): Embedded in Sphinx docs via literalinclude +3. Test validation and cleanup +""" + +import subprocess +import sys + +# Dependency setup +subprocess.check_call( + [sys.executable, "-m", "pip", "install", "--upgrade", "transformers", "datasets"] +) +subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "ray[llm]"]) +subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.26.4"]) + + +# __vlm_example_start__ +import ray +from PIL import Image +from io import BytesIO +from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor + +# Load "LMMs-Eval-Lite" dataset from Hugging Face +import datasets as datasets_lib + +vision_dataset_llms_lite = datasets_lib.load_dataset( + "lmms-lab/LMMs-Eval-Lite", "coco2017_cap_val" +) +vision_dataset = ray.data.from_huggingface(vision_dataset_llms_lite["lite"]) + +HF_TOKEN = "your-hf-token-here" # Replace with actual token if needed + +# __vlm_config_example_start__ +vision_processor_config = vLLMEngineProcessorConfig( + model_source="Qwen/Qwen2.5-VL-3B-Instruct", + engine_kwargs=dict( + tensor_parallel_size=1, + pipeline_parallel_size=1, + max_model_len=4096, + enable_chunked_prefill=True, + max_num_batched_tokens=2048, + ), + # Override Ray's runtime env to include the Hugging Face token. Ray Data uses Ray under the hood to orchestrate the inference pipeline. + runtime_env=dict( + env_vars=dict( + # HF_TOKEN=HF_TOKEN, # Token not needed for public models + VLLM_USE_V1="1", + ), + ), + batch_size=16, + accelerator_type="L4", + concurrency=1, + has_image=True, +) +# __vlm_config_example_end__ + + +def vision_preprocess(row: dict) -> dict: + """ + Preprocessing function for vision-language model inputs. + + Converts dataset rows into the format expected by the VLM: + - System prompt for analysis instructions + - User message with text and image content + - Multiple choice formatting + - Sampling parameters + """ + choice_indices = ["A", "B", "C", "D", "E", "F", "G", "H"] + + return { + "messages": [ + { + "role": "system", + "content": ( + "Analyze the image and question carefully, using step-by-step reasoning. " + "First, describe any image provided in detail. Then, present your reasoning. " + "And finally your final answer in this format: Final Answer: " + "where is: The single correct letter choice A, B, C, D, E, F, etc. when options are provided. " + "Only include the letter. Your direct answer if no options are given, as a single phrase or number. " + "IMPORTANT: Remember, to end your answer with Final Answer: ." + ), + }, + { + "role": "user", + "content": [ + {"type": "text", "text": row["question"] + "\n\n"}, + { + "type": "image", + # Ray Data accepts PIL Image or image URL + "image": Image.open(BytesIO(row["image"]["bytes"])), + }, + { + "type": "text", + "text": "\n\nChoices:\n" + + "\n".join( + [ + f"{choice_indices[i]}. {choice}" + for i, choice in enumerate(row["answer"]) + ] + ), + }, + ], + }, + ], + "sampling_params": { + "temperature": 0.3, + "max_tokens": 150, + "detokenize": False, + }, + # Include original data for reference + "original_data": { + "question": row["question"], + "answer_choices": row["answer"], + "image_size": row["image"].get("width", 0) if row["image"] else 0, + }, + } + + +def vision_postprocess(row: dict) -> dict: + return { + "resp": row["generated_text"], + } + + +vision_processor = build_llm_processor( + vision_processor_config, + preprocess=vision_preprocess, + postprocess=vision_postprocess, +) + + +def load_vision_dataset(): + """ + Load vision dataset from Hugging Face. + + This function loads the LMMs-Eval-Lite dataset which contains: + - Images with associated questions + - Multiple choice answers + - Various visual reasoning tasks + """ + try: + import datasets + + # Load "LMMs-Eval-Lite" dataset from Hugging Face + vision_dataset_llms_lite = datasets.load_dataset( + "lmms-lab/LMMs-Eval-Lite", "coco2017_cap_val" + ) + vision_dataset = ray.data.from_huggingface(vision_dataset_llms_lite["lite"]) + + return vision_dataset + except ImportError: + print( + "datasets package not available. Install with: pip install datasets>=4.0.0" + ) + return None + except Exception as e: + print(f"Error loading dataset: {e}") + return None + + +def create_vlm_config(): + """Create VLM configuration.""" + return vLLMEngineProcessorConfig( + model_source="Qwen/Qwen2.5-VL-3B-Instruct", + engine_kwargs=dict( + tensor_parallel_size=1, + pipeline_parallel_size=1, + max_model_len=4096, + trust_remote_code=True, + limit_mm_per_prompt={"image": 1}, + ), + runtime_env={ + # "env_vars": {"HF_TOKEN": "your-hf-token-here"} # Token not needed for public models + }, + batch_size=1, + accelerator_type="L4", + concurrency=1, + has_image=True, + ) + + +def run_vlm_example(): + """Run the complete VLM example workflow.""" + config = create_vlm_config() + vision_dataset = load_vision_dataset() + + if vision_dataset: + # Build processor with preprocessing + processor = build_llm_processor(config, preprocess=vision_preprocess) + + print("VLM processor configured successfully") + print(f"Model: {config.model_source}") + print(f"Has image support: {config.has_image}") + result = processor(vision_dataset).take_all() + return config, processor, result + return None, None, None + + +# __vlm_example_end__ + +if __name__ == "__main__": + # Run the example VLM workflow only if GPU is available + try: + import torch + + if torch.cuda.is_available(): + run_vlm_example() + else: + print("Skipping VLM example run (no GPU available)") + except Exception as e: + print(f"Skipping VLM example run due to environment error: {e}") diff --git a/doc/source/data/working-with-llms.rst b/doc/source/data/working-with-llms.rst index cfd0c4bedf77..cfa6a5ef8019 100644 --- a/doc/source/data/working-with-llms.rst +++ b/doc/source/data/working-with-llms.rst @@ -23,83 +23,50 @@ logic for performing batch inference with LLMs on a Ray Data dataset. You can use the :func:`build_llm_processor ` API to construct a processor. The following example uses the :class:`vLLMEngineProcessorConfig ` to construct a processor for the `unsloth/Llama-3.1-8B-Instruct` model. -To run this example, install vLLM, which is a popular and optimized LLM inference engine. +To start, install Ray Data + LLMs. This also installs vLLM, which is a popular and optimized LLM inference engine. -.. testcode:: +.. code-block:: bash - # Later versions *should* work but are not tested yet. - pip install -U vllm==0.7.2 + pip install -U "ray[data, llm]>=2.49.1" The :class:`vLLMEngineProcessorConfig ` is a configuration object for the vLLM engine. It contains the model name, the number of GPUs to use, and the number of shards to use, along with other vLLM engine configurations. Upon execution, the Processor object instantiates replicas of the vLLM engine (using :meth:`map_batches ` under the hood). -.. testcode:: - - import ray - from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor - - config = vLLMEngineProcessorConfig( - model_source="unsloth/Llama-3.1-8B-Instruct", - engine_kwargs={ - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4096, - "max_model_len": 16384, - }, - concurrency=1, - batch_size=64, - ) - processor = build_llm_processor( - config, - preprocess=lambda row: dict( - messages=[ - {"role": "system", "content": "You are a bot that responds with haikus."}, - {"role": "user", "content": row["item"]} - ], - sampling_params=dict( - temperature=0.3, - max_tokens=250, - ) - ), - postprocess=lambda row: dict( - answer=row["generated_text"], - **row # This will return all the original columns in the dataset. - ), - ) - - ds = ray.data.from_items(["Start of the haiku is: Complete this for me..."]) - - ds = processor(ds) - ds.show(limit=1) +.. .. literalinclude:: doc_code/working-with-llms/basic_llm_example.py +.. :language: python +.. :start-after: __basic_llm_example_start__ +.. :end-before: __basic_llm_example_end__ -.. testoutput:: - :options: +MOCK +Here's a simple configuration example: - {'answer': 'Snowflakes gently fall\nBlanketing the winter scene\nFrozen peaceful hush'} +.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py + :language: python + :start-after: __basic_config_example_start__ + :end-before: __basic_config_example_end__ -Each processor requires specific input columns. You can find more info by using the following API: +The configuration includes detailed comments explaining: -.. testcode:: +- **`concurrency`**: Number of vLLM engine replicas (typically 1 per node) +- **`batch_size`**: Number of samples processed per batch (reduce if GPU memory is limited) +- **`max_num_batched_tokens`**: Maximum tokens processed simultaneously (reduce if CUDA OOM occurs) +- **`accelerator_type`**: Specify GPU type for optimal resource allocation - processor.log_input_column_names() +Each processor requires specific input columns based on the model and configuration. The vLLM processor expects input in OpenAI chat format with a 'messages' column. -.. testoutput:: - :options: +MOCK +This basic configuration pattern is used throughout this guide and includes helpful comments explaining key parameters. - The first stage of the processor is ChatTemplateStage. - Required input columns: - messages: A list of messages in OpenAI chat format. See https://platform.openai.com/docs/api-reference/chat/create for details. +This configuration creates a processor that expects: -Some models may require a Hugging Face token to be specified. You can specify the token in the `runtime_env` argument. +- **Input**: Dataset with 'messages' column (OpenAI chat format) +- **Output**: Dataset with 'generated_text' column containing model responses -.. testcode:: +Some models may require a Hugging Face token to be specified. You can specify the token in the `runtime_env` argument. - config = vLLMEngineProcessorConfig( - model_source="unsloth/Llama-3.1-8B-Instruct", - runtime_env={"env_vars": {"HF_TOKEN": "your_huggingface_token"}}, - concurrency=1, - batch_size=64, - ) +.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py + :language: python + :start-after: __hf_token_config_example_start__ + :end-before: __hf_token_config_example_end__ .. _vllm_llm: @@ -108,33 +75,12 @@ Configure vLLM for LLM inference Use the :class:`vLLMEngineProcessorConfig ` to configure the vLLM engine. -.. testcode:: - - from ray.data.llm import vLLMEngineProcessorConfig - - config = vLLMEngineProcessorConfig( - model_source="unsloth/Llama-3.1-8B-Instruct", - engine_kwargs={"max_model_len": 20000}, - concurrency=1, - batch_size=64, - ) - -For handling larger models, specify model parallelism. +For handling larger models, specify model parallelism: -.. testcode:: - - config = vLLMEngineProcessorConfig( - model_source="unsloth/Llama-3.1-8B-Instruct", - engine_kwargs={ - "max_model_len": 16384, - "tensor_parallel_size": 2, - "pipeline_parallel_size": 2, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 2048, - }, - concurrency=1, - batch_size=64, - ) +.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py + :language: python + :start-after: __parallel_config_example_start__ + :end-before: __parallel_config_example_end__ The underlying :class:`Processor ` object instantiates replicas of the vLLM engine and automatically configure parallel workers to handle model parallelism (for tensor parallelism and pipeline parallelism, @@ -143,47 +89,26 @@ if specified). To optimize model loading, you can configure the `load_format` to `runai_streamer` or `tensorizer`. .. note:: - In this case, install vLLM with runai dependencies: `pip install -U "vllm[runai]==0.7.2"` - -.. testcode:: + In this case, install vLLM with runai dependencies: `pip install -U "vllm[runai]>=0.10.1"` - config = vLLMEngineProcessorConfig( - model_source="unsloth/Llama-3.1-8B-Instruct", - engine_kwargs={"load_format": "runai_streamer"}, - concurrency=1, - batch_size=64, - ) +.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py + :language: python + :start-after: __runai_config_example_start__ + :end-before: __runai_config_example_end__ If your model is hosted on AWS S3, you can specify the S3 path in the `model_source` argument, and specify `load_format="runai_streamer"` in the `engine_kwargs` argument. -.. testcode:: - - config = vLLMEngineProcessorConfig( - model_source="s3://your-bucket/your-model/", # Make sure adding the trailing slash! - engine_kwargs={"load_format": "runai_streamer"}, - runtime_env={"env_vars": { - "AWS_ACCESS_KEY_ID": "your_access_key_id", - "AWS_SECRET_ACCESS_KEY": "your_secret_access_key", - "AWS_REGION": "your_region", - }}, - concurrency=1, - batch_size=64, - ) +.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py + :language: python + :start-after: __s3_config_example_start__ + :end-before: __s3_config_example_end__ To do multi-LoRA batch inference, you need to set LoRA related parameters in `engine_kwargs`. See :doc:`the vLLM with LoRA example` for details. -.. testcode:: - - config = vLLMEngineProcessorConfig( - model_source="unsloth/Llama-3.1-8B-Instruct", - engine_kwargs={ - enable_lora=True, - max_lora_rank=32, - max_loras=1, - }, - concurrency=1, - batch_size=64, - ) +.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py + :language: python + :start-after: __lora_config_example_start__ + :end-before: __lora_config_example_end__ .. _vision_language_model: @@ -199,90 +124,43 @@ This example applies 2 adjustments on top of the previous example: - set `has_image=True` in `vLLMEngineProcessorConfig` - prepare image input inside preprocessor -.. testcode:: - - # Load "LMMs-Eval-Lite" dataset from Hugging Face. - vision_dataset_llms_lite = datasets.load_dataset("lmms-lab/LMMs-Eval-Lite", "coco2017_cap_val") - vision_dataset = ray.data.from_huggingface(vision_dataset_llms_lite["lite"]) - - vision_processor_config = vLLMEngineProcessorConfig( - model_source="Qwen/Qwen2.5-VL-3B-Instruct", - engine_kwargs=dict( - tensor_parallel_size=1, - pipeline_parallel_size=1, - max_model_len=4096, - enable_chunked_prefill=True, - max_num_batched_tokens=2048, - ), - # Override Ray's runtime env to include the Hugging Face token. Ray Data uses Ray under the hood to orchestrate the inference pipeline. - runtime_env=dict( - env_vars=dict( - HF_TOKEN=HF_TOKEN, - VLLM_USE_V1="1", - ), - ), - batch_size=16, - accelerator_type="L4", - concurrency=1, - has_image=True, - ) - - def vision_preprocess(row: dict) -> dict: - choice_indices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'] - return dict( - messages=[ - { - "role": "system", - "content": """Analyze the image and question carefully, using step-by-step reasoning. - First, describe any image provided in detail. Then, present your reasoning. And finally your final answer in this format: - Final Answer: - where is: - - The single correct letter choice A, B, C, D, E, F, etc. when options are provided. Only include the letter. - - Your direct answer if no options are given, as a single phrase or number. - - If your answer is a number, only include the number without any unit. - - If your answer is a word or phrase, do not paraphrase or reformat the text you see in the image. - - You cannot answer that the question is unanswerable. You must either pick an option or provide a direct answer. - IMPORTANT: Remember, to end your answer with Final Answer: .""", - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": row["question"] + "\n\n" - }, - { - "type": "image", - # Ray Data accepts PIL Image or image URL. - "image": Image.open(BytesIO(row["image"]["bytes"])) - }, - { - "type": "text", - "text": "\n\nChoices:\n" + "\n".join([f"{choice_indices[i]}. {choice}" for i, choice in enumerate(row["answer"])]) - } - ] - }, - ], - sampling_params=dict( - temperature=0.3, - max_tokens=150, - detokenize=False, - ), - ) - - def vision_postprocess(row: dict) -> dict: - return { - "resp": row["generated_text"], - } - - vision_processor = build_llm_processor( - vision_processor_config, - preprocess=vision_preprocess, - postprocess=vision_postprocess, - ) - - vision_processed_ds = vision_processor(vision_dataset).materialize() - vision_processed_ds.show(3) +First, install the required dependencies: + +.. code-block:: bash + + # Install required dependencies for vision-language models + pip install datasets>=4.0.0 + +First, load a vision dataset: + +.. literalinclude:: doc_code/working-with-llms/vlm_example.py + :language: python + :start-after: def load_vision_dataset(): + :end-before: def create_vlm_config(): + :dedent: 0 + +Next, configure the VLM processor with the essential settings: + +.. literalinclude:: doc_code/working-with-llms/vlm_example.py + :language: python + :start-after: __vlm_config_example_start__ + :end-before: __vlm_config_example_end__ + +For a more comprehensive VLM configuration with advanced options: + +.. literalinclude:: doc_code/working-with-llms/vlm_example.py + :language: python + :start-after: def create_vlm_config(): + :end-before: def run_vlm_example(): + :dedent: 0 + +Finally, run the VLM inference: + +.. literalinclude:: doc_code/working-with-llms/vlm_example.py + :language: python + :start-after: def run_vlm_example(): + :end-before: # __vlm_example_end__ + :dedent: 0 .. _embedding_models: @@ -291,44 +169,10 @@ Batch inference with embedding models Ray Data LLM supports batch inference with embedding models using vLLM: -.. testcode:: - - import ray - from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor - - embedding_config = vLLMEngineProcessorConfig( - model_source="sentence-transformers/all-MiniLM-L6-v2", - task_type="embed", - engine_kwargs=dict( - enable_prefix_caching=False, - enable_chunked_prefill=False, - max_model_len=256, - enforce_eager=True, - ), - batch_size=32, - concurrency=1, - apply_chat_template=False, - detokenize=False, - ) - - embedding_processor = build_llm_processor( - embedding_config, - preprocess=lambda row: dict(prompt=row["text"]), - postprocess=lambda row: { - "text": row["prompt"], - "embedding": row["embeddings"], - }, - ) - - texts = [ - "Hello world", - "This is a test sentence", - "Embedding models convert text to vectors", - ] - ds = ray.data.from_items([{"text": text} for text in texts]) - - embedded_ds = embedding_processor(ds) - embedded_ds.show(limit=1) +.. literalinclude:: doc_code/working-with-llms/embedding_example.py + :language: python + :start-after: __embedding_example_start__ + :end-before: __embedding_example_end__ .. testoutput:: :options: +MOCK @@ -342,6 +186,13 @@ Key differences for embedding models: - Use direct ``prompt`` input instead of ``messages`` - Access embeddings through``row["embeddings"]`` +For a complete embedding configuration example, see: + +.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py + :language: python + :start-after: __embedding_config_example_start__ + :end-before: __embedding_config_example_end__ + .. _openai_compatible_api_endpoint: Batch inference with an OpenAI-compatible endpoint @@ -349,40 +200,10 @@ Batch inference with an OpenAI-compatible endpoint You can also make calls to deployed models that have an OpenAI compatible API endpoint. -.. testcode:: - - import ray - import os - from ray.data.llm import HttpRequestProcessorConfig, build_llm_processor - - OPENAI_KEY = os.environ["OPENAI_API_KEY"] - ds = ray.data.from_items(["Hand me a haiku."]) - - - config = HttpRequestProcessorConfig( - url="https://api.openai.com/v1/chat/completions", - headers={"Authorization": f"Bearer {OPENAI_KEY}"}, - qps=1, - ) - - processor = build_llm_processor( - config, - preprocess=lambda row: dict( - payload=dict( - model="gpt-4o-mini", - messages=[ - {"role": "system", "content": "You are a bot that responds with haikus."}, - {"role": "user", "content": row["item"]} - ], - temperature=0.0, - max_tokens=150, - ), - ), - postprocess=lambda row: dict(response=row["http_response"]["choices"][0]["message"]["content"]), - ) - - ds = processor(ds) - print(ds.take_all()) +.. literalinclude:: doc_code/working-with-llms/openai_api_example.py + :language: python + :start-after: __openai_example_start__ + :end-before: __openai_example_end__ Usage Data Collection -------------------------- @@ -407,6 +228,7 @@ Frequently Asked Questions (FAQs) -------------------------------------------------- .. TODO(#55491): Rewrite this section once the restriction is lifted. +.. TODO(#55405): Cross-node TP in progress. .. _cross_node_parallelism: How to configure LLM stage to parallelize across multiple nodes? @@ -426,6 +248,28 @@ as long as each replica (TP * PP) fits into a single node. The number of replicas is configured by the `concurrency` argument in :class:`vLLMEngineProcessorConfig `. +.. _gpu_memory_management: + +GPU Memory Management and CUDA OOM Prevention +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you encounter CUDA out of memory errors, Ray Data LLM provides several configuration options to optimize GPU memory usage: + +.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py + :language: python + :start-after: __gpu_memory_config_example_start__ + :end-before: __gpu_memory_config_example_end__ + +**Key strategies for handling GPU memory issues:** + +- **Reduce batch size**: Start with smaller batches (8-16) and increase gradually +- **Lower `max_num_batched_tokens`**: Reduce from 4096 to 2048 or 1024 +- **Decrease `max_model_len`**: Use shorter context lengths when possible +- **Set `gpu_memory_utilization`**: Use 0.75-0.85 instead of default 0.90 +- **Use smaller models**: Consider using smaller model variants for resource-constrained environments + +If you run into CUDA out of memory, your batch size is likely too large. Set an explicit small batch size or use a smaller model, or a larger GPU. + .. _model_cache: How to cache model weight to remote object storage @@ -437,7 +281,7 @@ storage (AWS S3 or Google Cloud Storage) for more stable model loading. Ray Data LLM provides the following utility to help uploading models to remote object storage. -.. testcode:: +.. code-block:: bash # Download model from HuggingFace, and upload to GCS python -m ray.llm.utils.upload_model \ @@ -450,9 +294,7 @@ Ray Data LLM provides the following utility to help uploading models to remote o And later you can use remote object store URI as `model_source` in the config. -.. testcode:: - - config = vLLMEngineProcessorConfig( - model_source="gs://my-bucket/path/to/facebook-opt-350m", # or s3://my-bucket/path/to/model_name - ... - ) +.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py + :language: python + :start-after: __s3_config_example_start__ + :end-before: __s3_config_example_end__