Commit d1557e6

[Misc] Enhance offline_inference to support user-configurable paramet… (#10392)
Signed-off-by: wchen61 <[email protected]>
1 parent 80d85c5 commit d1557e6

File tree

1 file changed: +78 -20 lines changed

examples/offline_inference.py

Lines changed: 78 additions & 20 deletions
@@ -1,22 +1,80 @@
+from dataclasses import asdict
+
 from vllm import LLM, SamplingParams
+from vllm.engine.arg_utils import EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+
+def get_prompts(num_prompts: int):
+    # The default sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    if num_prompts != len(prompts):
+        prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts]
+
+    return prompts
+
+
+def main(args):
+    # Create prompts
+    prompts = get_prompts(args.num_prompts)
+
+    # Create a sampling params object.
+    sampling_params = SamplingParams(n=args.n,
+                                     temperature=args.temperature,
+                                     top_p=args.top_p,
+                                     top_k=args.top_k,
+                                     max_tokens=args.max_tokens)
+
+    # Create an LLM.
+    # The default model is 'facebook/opt-125m'
+    engine_args = EngineArgs.from_cli_args(args)
+    llm = LLM(**asdict(engine_args))
+
+    # Generate texts from the prompts.
+    # The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+if __name__ == '__main__':
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    group = parser.add_argument_group("SamplingParams options")
+    group.add_argument("--num-prompts",
+                       type=int,
+                       default=4,
+                       help="Number of prompts used for inference")
+    group.add_argument("--max-tokens",
+                       type=int,
+                       default=16,
+                       help="Generated output length for sampling")
+    group.add_argument('--n',
+                       type=int,
+                       default=1,
+                       help='Number of generated sequences per prompt')
+    group.add_argument('--temperature',
+                       type=float,
+                       default=0.8,
+                       help='Temperature for text generation')
+    group.add_argument('--top-p',
+                       type=float,
+                       default=0.95,
+                       help='top_p for text generation')
+    group.add_argument('--top-k',
+                       type=int,
+                       default=-1,
+                       help='top_k for text generation')
 
-# Sample prompts.
-prompts = [
-    "Hello, my name is",
-    "The president of the United States is",
-    "The capital of France is",
-    "The future of AI is",
-]
-# Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
-
-# Create an LLM.
-llm = LLM(model="facebook/opt-125m")
-# Generate texts from the prompts. The output is a list of RequestOutput objects
-# that contain the prompt, generated text, and other information.
-outputs = llm.generate(prompts, sampling_params)
-# Print the outputs.
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    args = parser.parse_args()
+    main(args)
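
With this change the example is driven entirely from the command line: EngineArgs.add_cli_args exposes the engine options (e.g. --model; the default model remains 'facebook/opt-125m'), and the new "SamplingParams options" group exposes the sampling flags defined in the diff. A possible invocation, with illustrative values rather than anything mandated by the commit:

    python examples/offline_inference.py --num-prompts 8 --temperature 0.7 --top-p 0.9 --max-tokens 32

Flags left unset fall back to the defaults shown above (4 prompts, 16 output tokens, n=1, temperature 0.8, top_p 0.95, top_k -1).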
