 
 
 class LLM:
+    """An LLM for generating texts from given prompts and sampling parameters.
+
+    This class includes a tokenizer, a language model (possibly distributed
+    across multiple GPUs), and GPU memory space allocated for intermediate
+    states (aka KV cache). Given a batch of prompts and sampling parameters,
+    this class generates texts from the model, using an intelligent batching
+    mechanism and efficient memory management.
+
+    NOTE: This class is intended to be used for offline inference. For online
+    serving, use the `AsyncLLMServer` class instead.
+    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model: The name or path of a HuggingFace Transformers model.
+        tensor_parallel_size: The number of GPUs to use for distributed
+            execution with tensor parallelism.
+        dtype: The data type for the model weights and activations. Currently,
+            we support `float16` and `bfloat16`. If `default`, we use the
+            `torch_dtype` attribute of the model config. If the `torch_dtype`
+            is `float32`, we use `float16` instead.
+        seed: The seed to initialize the random number generator for sampling.
+    """
 
     def __init__(
         self,
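
The docstring above documents the constructor arguments. A minimal usage
sketch follows, assuming the class is importable as `from vllm import LLM`
(the import path, model name, and argument values are illustrative
assumptions, not part of this commit):

# Sketch: constructing an LLM with the documented arguments.
from vllm import LLM  # assumed import path

llm = LLM(
    model="facebook/opt-125m",  # name or path of a HuggingFace Transformers model
    tensor_parallel_size=1,     # number of GPUs for tensor-parallel execution
    dtype="default",            # use the torch_dtype from the model config
    seed=0,                     # seed for the sampling RNG
)
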
@@ -39,19 +61,50 @@ def get_tokenizer(
 
     def generate(
         self,
-        prompts: Union[str, List[str]],
+        prompts: Optional[Union[str, List[str]]] = None,
         sampling_params: Optional[SamplingParams] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
+        """Generates the completions for the input prompts.
+
+        NOTE: This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
+
+        Args:
+            prompts: A list of prompts to generate completions for.
+            sampling_params: The sampling parameters for text generation. If
+                None, we use the default sampling parameters.
+            prompt_token_ids: A list of token IDs for the prompts. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+
+        Returns:
+            A list of `RequestOutput` objects containing the generated
+            completions in the same order as the input prompts.
+        """
+        if prompts is None and prompt_token_ids is None:
+            raise ValueError("Either prompts or prompt_token_ids must be "
+                             "provided.")
         if isinstance(prompts, str):
+            # Convert a single prompt to a list.
             prompts = [prompts]
+        if prompts is not None and prompt_token_ids is not None:
+            if len(prompts) != len(prompt_token_ids):
+                raise ValueError("The lengths of prompts and prompt_token_ids "
+                                 "must be the same.")
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
+
         # Add requests to the server.
-        for i in range(len(prompts)):
-            prompt = prompts[i]
+        if prompts is not None:
+            num_requests = len(prompts)
+        else:
+            num_requests = len(prompt_token_ids)
+        for i in range(num_requests):
+            prompt = prompts[i] if prompts is not None else None
             if prompt_token_ids is None:
                 token_ids = None
             else:
@@ -61,7 +114,7 @@ def generate(
 
     def _add_request(
         self,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
     ) -> None:
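
With `prompts` now optional, `generate` accepts either plain strings or
pre-tokenized inputs. A hedged usage sketch (the `SamplingParams` fields,
token IDs, and import path are illustrative assumptions):

# Sketch: the two calling modes allowed by the new `generate` signature.
from vllm import SamplingParams  # assumed import path

params = SamplingParams(temperature=0.8, top_p=0.95)

# 1) String prompts; a single str is wrapped into a list internally.
outputs = llm.generate("Hello, my name is", sampling_params=params)

# 2) Pre-tokenized prompts only; `prompts` may be omitted entirely.
outputs = llm.generate(prompt_token_ids=[[101, 2023, 2003]],
                       sampling_params=params)

# Passing neither prompts nor prompt_token_ids, or both with mismatched
# lengths, raises a ValueError.
for output in outputs:
    print(output)  # one RequestOutput per input, in the original order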