Commit 8274ca2

Add docstrings for LLM (#137)
1 parent 62ec38e commit 8274ca2

File tree

    benchmarks/benchmark_latency.py
    benchmarks/benchmark_throughput.py
    cacheflow/entrypoints/llm.py
    cacheflow/server/llm_server.py

4 files changed: +66 -10 lines changed

benchmarks/benchmark_latency.py

Lines changed: 2 additions & 2 deletions
@@ -30,15 +30,15 @@ def main(args: argparse.Namespace):
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    dummy_prompts = [""] * args.batch_size
     dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
 
     def run_to_completion(profile: bool = False):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
         start_time = time.time()
 
-        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                     sampling_params=sampling_params,
                      use_tqdm=False)
 
         end_time = time.time()
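
For reference, a minimal sketch of the keyword-only call pattern the latency benchmark now uses. The constructor arguments come from the docstring added in this commit; the import path and model name are assumptions for illustration, not part of the diff.

# Sketch: timing generate() on dummy token IDs (illustrative, not from this commit).
import time

from cacheflow import LLM, SamplingParams  # assumed top-level exports

llm = LLM(model="facebook/opt-125m")          # hypothetical model choice
sampling_params = SamplingParams(max_tokens=128)
dummy_prompt_token_ids = [[0] * 32] * 8       # batch of 8 prompts, 32 tokens each

start = time.time()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
             sampling_params=sampling_params,
             use_tqdm=False)
print(f"Latency: {time.time() - start:.3f} s")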

benchmarks/benchmark_throughput.py

Lines changed: 5 additions & 3 deletions
@@ -72,9 +72,9 @@ def main(args: argparse.Namespace):
         )
         # FIXME(woosuk): Do not use internal method.
         llm._add_request(
-            prompt="",
-            sampling_params=sampling_params,
+            prompt=None,
             prompt_token_ids=prompt_token_ids,
+            sampling_params=sampling_params,
         )
 
     start = time.time()
@@ -85,7 +85,9 @@ def main(args: argparse.Namespace):
         len(prompt_token_ids) + output_len
         for prompt_token_ids, output_len in requests
     )
-    print(f"Throughput: {total_num_tokens / (end - start):.2f} tokens/s")
+    elapsed_time = end - start
+    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
+          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
 
 
 if __name__ == "__main__":
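
The revised report derives request-level and token-level throughput from the same elapsed wall-clock time. A standalone sketch of that computation (function name chosen here for illustration):

# Sketch of the new throughput report: requests/s alongside tokens/s.
# `requests` is assumed to be a list of (prompt_token_ids, output_len) pairs.
def report_throughput(requests, elapsed_time):
    total_num_tokens = sum(
        len(prompt_token_ids) + output_len
        for prompt_token_ids, output_len in requests
    )
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")

report_throughput([([0] * 32, 16)] * 4, elapsed_time=2.0)  # toy input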

cacheflow/entrypoints/llm.py

Lines changed: 57 additions & 4 deletions
@@ -11,6 +11,28 @@
 
 
 class LLM:
+    """An LLM for generating texts from given prompts and sampling parameters.
+
+    This class includes a tokenizer, a language model (possibly distributed
+    across multiple GPUs), and GPU memory space allocated for intermediate
+    states (aka KV cache). Given a batch of prompts and sampling parameters,
+    this class generates texts from the model, using an intelligent batching
+    mechanism and efficient memory management.
+
+    NOTE: This class is intended to be used for offline inference. For online
+    serving, use the `AsyncLLMServer` class instead.
+    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model: The name or path of a HuggingFace Transformers model.
+        tensor_parallel_size: The number of GPUs to use for distributed
+            execution with tensor parallelism.
+        dtype: The data type for the model weights and activations. Currently,
+            we support `float16` and `bfloat16`. If `default`, we use the
+            `torch_dtype` attribute of the model config. If the `torch_dtype`
+            is `float32`, we use `float16` instead.
+        seed: The seed to initialize the random number generator for sampling.
+    """
 
     def __init__(
         self,
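
Taken together with the new docstring, a hedged offline-inference sketch might look like the following. The import location and model name are illustrative assumptions; only the argument names come from the docstring above.

# Sketch of offline inference with the documented constructor arguments.
from cacheflow import LLM, SamplingParams  # assumed top-level exports

llm = LLM(
    model="facebook/opt-125m",   # name or path of a HuggingFace model (example)
    tensor_parallel_size=1,      # single-GPU execution
    dtype="default",             # fall back to the model config's torch_dtype
    seed=0,                      # seed for the sampling RNG
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))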
@@ -39,19 +61,50 @@ def get_tokenizer(
 
     def generate(
         self,
-        prompts: Union[str, List[str]],
+        prompts: Optional[Union[str, List[str]]] = None,
         sampling_params: Optional[SamplingParams] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
+        """Generates the completions for the input prompts.
+
+        NOTE: This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
+
+        Args:
+            prompts: A list of prompts to generate completions for.
+            sampling_params: The sampling parameters for text generation. If
+                None, we use the default sampling parameters.
+            prompt_token_ids: A list of token IDs for the prompts. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+
+        Returns:
+            A list of `RequestOutput` objects containing the generated
+            completions in the same order as the input prompts.
+        """
+        if prompts is None and prompt_token_ids is None:
+            raise ValueError("Either prompts or prompt_token_ids must be "
+                             "provided.")
         if isinstance(prompts, str):
+            # Convert a single prompt to a list.
             prompts = [prompts]
+        if prompts is not None and prompt_token_ids is not None:
+            if len(prompts) != len(prompt_token_ids):
+                raise ValueError("The lengths of prompts and prompt_token_ids "
+                                 "must be the same.")
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
+
         # Add requests to the server.
-        for i in range(len(prompts)):
-            prompt = prompts[i]
+        if prompts is not None:
+            num_requests = len(prompts)
+        else:
+            num_requests = len(prompt_token_ids)
+        for i in range(num_requests):
+            prompt = prompts[i] if prompts is not None else None
             if prompt_token_ids is None:
                 token_ids = None
             else:
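
A short sketch of the two call modes this change enables (text prompts, or pre-tokenized IDs with no prompt text), assuming an llm object constructed as in the sketch above:

# Two ways to submit work after this change (illustrative, not from the diff).
params = SamplingParams(max_tokens=16)

# 1) Text prompts: the tokenizer converts them to token IDs internally.
outputs = llm.generate(prompts=["Hello, my name is", "The capital of France is"],
                       sampling_params=params)

# 2) Pre-tokenized prompts: skip the tokenizer entirely.
outputs = llm.generate(prompt_token_ids=[[0, 1, 2, 3], [4, 5, 6]],
                       sampling_params=params)

# Passing neither raises ValueError; passing both requires equal lengths.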
@@ -61,7 +114,7 @@ def generate(
 
     def _add_request(
         self,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
     ) -> None:

cacheflow/server/llm_server.py

Lines changed: 2 additions & 1 deletion
@@ -126,14 +126,15 @@ def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
     def add_request(
         self,
         request_id: str,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
     ) -> None:
         if arrival_time is None:
             arrival_time = time.time()
         if prompt_token_ids is None:
+            assert prompt is not None
             prompt_token_ids = self.tokenizer.encode(prompt)
 
         # Create the sequences.
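
The added assert states the invariant behind this change: when no token IDs are supplied, a text prompt must be present for the tokenizer fallback. A standalone sketch of that resolution logic (the helper name is illustrative, not the server's actual API):

from typing import List, Optional

def resolve_token_ids(prompt: Optional[str],
                      prompt_token_ids: Optional[List[int]],
                      tokenizer) -> List[int]:
    # Illustrative helper mirroring the fallback in add_request().
    if prompt_token_ids is None:
        assert prompt is not None, "need a text prompt when no token IDs are given"
        return tokenizer.encode(prompt)
    return prompt_token_ids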
