Conversation

hzjane (Collaborator) commented on Nov 26, 2024

Refer to #48. Prefix caching is enabled for multimodal input too.
Performance still needs to be compared; the script below benchmarks
generation with and without enable_prefix_caching:

from time import time

import torch

from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

# Common prefix.
prefix = (
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for a 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on these information, fulfill "
    "the following paragraph: ")

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

generating_prompts = [prefix + prompt for prompt in prompts]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# Create a regular LLM without prefix caching.
regular_llm = LLM(
          model="/llm/models/Qwen1.5-7B-Chat",
          device="xpu",
          dtype="float16",
          trust_remote_code=True,
          enforce_eager=True,
          load_in_low_bit="sym_int4",
          #load_in_low_bit="fp8",
          tensor_parallel_size=1,
          disable_async_output_proc=True,
          #distributed_executor_backend="ray",
          max_model_len=1000,
          max_num_batched_tokens=1000,
          gpu_memory_utilization=0.45)

print("Results without `enable_prefix_caching`")

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
# The first call is a warmup so that one-time startup overhead is excluded
# from the timing of the second call.
outputs = regular_llm.generate(generating_prompts, sampling_params)

# Timed generation without prefix caching.
start_time_regular = time()
outputs = regular_llm.generate(generating_prompts, sampling_params)
duration_regular = time() - start_time_regular
# Release the regular engine and free its GPU memory before loading the
# prefix-cached engine.
del regular_llm
torch.xpu.empty_cache()
# Create a second LLM with prefix caching enabled.
prefix_cached_llm = LLM(
          model="/llm/models/Qwen1.5-7B-Chat",
          device="xpu",
          dtype="float16",
          trust_remote_code=True,
          enforce_eager=True,
          load_in_low_bit="sym_int4",
          #load_in_low_bit="fp8",
          enable_prefix_caching=True,
          tensor_parallel_size=1,
          disable_async_output_proc=True,
          #distributed_executor_backend="ray",
          max_model_len=1000,
          max_num_batched_tokens=1000,
          block_size=8, # must be set to get the right output
          gpu_memory_utilization=0.45)

regular_generated_texts = []
# Print the outputs of the regular run.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    regular_generated_texts.append(generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params)

# Generate with prefix caching.
start_time_cached = time()
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
duration_cached = time() - start_time_cached

print("Results with `enable_prefix_caching`")

cached_generated_texts = []
# Print the outputs. You should see the same outputs as before.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    cached_generated_texts.append(generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Compare the results and display the speedup.
generated_same = all(
    regular_generated_texts[i] == cached_generated_texts[i]
    for i in range(len(prompts))
)
# Per-prompt comparison, printed for easier debugging.
for i in range(len(prompts)):
    print(regular_generated_texts[i] == cached_generated_texts[i])
print(f"Generated answers are the same: {generated_same}")

print(f"duration_regular {duration_regular}s")
print(f"duration_cached {duration_cached}s")
speedup = round(duration_regular / duration_cached, 2)
print(f"Speed up of cached generation compared to the regular is: {speedup}")

hzjane force-pushed the add_prefix_caching branch from 29b3cbb to 10ad43c on November 27, 2024 06:04
gc-fu (Collaborator) left a comment


Besides these two comments, the others LGTM.

if computed_block_nums is not None:
    context_len = len(computed_block_nums) * self.block_size
else:
    context_len = 0

Considering chunked prefill, this should probably be seq_data.get_num_computed_tokens()?
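
A minimal sketch of the reviewer's suggestion, assuming vLLM's SequenceData API (get_num_computed_tokens() also counts tokens already processed by chunked prefill):

# Sketch: instead of counting only cached prefix blocks, take the context
# length from the sequence state, which also covers chunked-prefill progress.
context_len = seq_data.get_num_computed_tokens()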

# Last token
tokens = [seq_data.get_last_token_id()]

# FIXME: add prefix caching

Remove this FIXME?
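
For reference, with the stale FIXME dropped the snippet would simply read:

# Last token
tokens = [seq_data.get_last_token_id()]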

gc-fu (Collaborator) commented on Nov 28, 2024

Also, for multi-modal input, I am not very sure...

gc-fu merged commit 5136c06 into analytics-zoo:0.6.2 on Nov 29, 2024