| 1 | +from dataclasses import asdict |
| 2 | + |
1 | 3 | from vllm import LLM, SamplingParams |
| 4 | +from vllm.engine.arg_utils import EngineArgs |
| 5 | +from vllm.utils import FlexibleArgumentParser |
| 6 | + |
| 7 | + |
| 8 | +def get_prompts(num_prompts: int): |
| 9 | + # The default sample prompts. |
| 10 | + prompts = [ |
| 11 | + "Hello, my name is", |
| 12 | + "The president of the United States is", |
| 13 | + "The capital of France is", |
| 14 | + "The future of AI is", |
| 15 | + ] |
| 16 | + |
| 17 | + if num_prompts != len(prompts): |
| 18 | + prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts] |
| 19 | + |
| 20 | + return prompts |
| 21 | + |
| 22 | + |
| 23 | +def main(args): |
| 24 | + # Create prompts |
| 25 | + prompts = get_prompts(args.num_prompts) |
| 26 | + |
| 27 | + # Create a sampling params object. |
| 28 | + sampling_params = SamplingParams(n=args.n, |
| 29 | + temperature=args.temperature, |
| 30 | + top_p=args.top_p, |
| 31 | + top_k=args.top_k, |
| 32 | + max_tokens=args.max_tokens) |
| 33 | + |
| 34 | + # Create an LLM. |
| 35 | + # The default model is 'facebook/opt-125m' |
| 36 | + engine_args = EngineArgs.from_cli_args(args) |
| 37 | + llm = LLM(**asdict(engine_args)) |
| 38 | + |
| 39 | + # Generate texts from the prompts. |
| 40 | + # The output is a list of RequestOutput objects |
| 41 | + # that contain the prompt, generated text, and other information. |
| 42 | + outputs = llm.generate(prompts, sampling_params) |
| 43 | + # Print the outputs. |
| 44 | + for output in outputs: |
| 45 | + prompt = output.prompt |
| 46 | + generated_text = output.outputs[0].text |
| 47 | + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") |
| 48 | + |
| 49 | + |
| 50 | +if __name__ == '__main__': |
| 51 | + parser = FlexibleArgumentParser() |
| 52 | + parser = EngineArgs.add_cli_args(parser) |
| 53 | + group = parser.add_argument_group("SamplingParams options") |
| 54 | + group.add_argument("--num-prompts", |
| 55 | + type=int, |
| 56 | + default=4, |
| 57 | + help="Number of prompts used for inference") |
| 58 | + group.add_argument("--max-tokens", |
| 59 | + type=int, |
| 60 | + default=16, |
| 61 | + help="Generated output length for sampling") |
| 62 | + group.add_argument('--n', |
| 63 | + type=int, |
| 64 | + default=1, |
| 65 | + help='Number of generated sequences per prompt') |
| 66 | + group.add_argument('--temperature', |
| 67 | + type=float, |
| 68 | + default=0.8, |
| 69 | + help='Temperature for text generation') |
| 70 | + group.add_argument('--top-p', |
| 71 | + type=float, |
| 72 | + default=0.95, |
| 73 | + help='top_p for text generation') |
| 74 | + group.add_argument('--top-k', |
| 75 | + type=int, |
| 76 | + default=-1, |
| 77 | + help='top_k for text generation') |
2 | 78 | |
3 | | -# Sample prompts. |
4 | | -prompts = [ |
5 | | - "Hello, my name is", |
6 | | - "The president of the United States is", |
7 | | - "The capital of France is", |
8 | | - "The future of AI is", |
9 | | -] |
10 | | -# Create a sampling params object. |
11 | | -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) |
12 | | - |
13 | | -# Create an LLM. |
14 | | -llm = LLM(model="facebook/opt-125m") |
15 | | -# Generate texts from the prompts. The output is a list of RequestOutput objects |
16 | | -# that contain the prompt, generated text, and other information. |
17 | | -outputs = llm.generate(prompts, sampling_params) |
18 | | -# Print the outputs. |
19 | | -for output in outputs: |
20 | | - prompt = output.prompt |
21 | | - generated_text = output.outputs[0].text |
22 | | - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") |
| 79 | + args = parser.parse_args() |
| 80 | + main(args) |
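
For context, the core of this change is wiring `EngineArgs` through the CLI parser and into the `LLM` constructor via `asdict`. Below is a minimal sketch of the same pattern used programmatically, without `FlexibleArgumentParser`; the model name and sampling values simply mirror the defaults in the diff and are not part of the commit itself.

```python
from dataclasses import asdict

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs

# Build engine arguments directly instead of parsing them from the CLI.
# 'facebook/opt-125m' is the default model noted in the example above.
engine_args = EngineArgs(model="facebook/opt-125m")
llm = LLM(**asdict(engine_args))

# Sampling settings matching the script's CLI defaults.
sampling_params = SamplingParams(
    n=1, temperature=0.8, top_p=0.95, top_k=-1, max_tokens=16
)

outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
```

When run as a script, the updated example uses these same defaults, with `--num-prompts`, `--max-tokens`, `--n`, `--temperature`, `--top-p`, and `--top-k` available to override them, alongside the standard engine arguments added by `EngineArgs.add_cli_args`.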