diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index e33d5fb2dc24..d75d690cc66d 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -65,7 +65,9 @@ def run_to_completion(profile_dir: Optional[str] = None):
     if args.profile:
         profile_dir = args.profile_result_dir
         if not profile_dir:
-            profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+            profile_dir = Path(
+                "."
+            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
         print(f"Profiling (results will be saved to '{profile_dir}')...")
         run_to_completion(profile_dir=args.profile_result_dir)
         return
@@ -123,9 +125,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
         '--profile-result-dir',
         type=str,
         default=None,
-        help=(
-            'path to save the pytorch profiler output. Can be visualized '
-            'with ui.perfetto.dev or Tensorboard.'
-        ))
+        help=('path to save the pytorch profiler output. Can be visualized '
+              'with ui.perfetto.dev or Tensorboard.'))
     args = parser.parse_args()
     main(args)
diff --git a/bin/build.sh b/bin/build.sh
new file mode 100644
index 000000000000..6fe5313b1cd4
--- /dev/null
+++ b/bin/build.sh
@@ -0,0 +1,14 @@
+#!/bin/sh
+pkg_deploy() {
+    tar_src=`find dist -type f | grep "tar.gz" | grep -v cli`
+    pip install -q oss2 -i https://mirrors.aliyun.com/pypi/simple
+    python $ROOT/bin/osscli $tar_src alps/`basename $tar_src`
+}
+cd ..
+rm -rf dist
+python setup.py sdist
+
+ROOT=$(pwd)
+echo $ROOT
+
+pkg_deploy
diff --git a/bin/osscli b/bin/osscli
new file mode 100644
index 000000000000..a2764b0ce258
--- /dev/null
+++ b/bin/osscli
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+import oss2
+import sys
+
+def decode(key):
+    import base64
+    return base64.standard_b64decode(key).decode()
+
+accessid=decode("TFRBSUN1S0hFSzFFNnNXNw==")
+accesskey=decode("UFJ5YnNnVU1taHVKSUNXMUp3U0FKcDlwS1NjbXlR")
+host="oss-cn-hangzhou-zmf.aliyuncs.com"
+bucket="alps-common"
+src=sys.argv[1]
+target=sys.argv[2]
+if target.endswith("/"):
+    import os
+    name = os.path.basename(src)
+    target = os.path.join(target, name)
+oss_bucket = oss2.Bucket(oss2.Auth(accessid, accesskey), host, bucket, connect_timeout=300)
+out = oss_bucket.put_object_from_file(target, src)
+addr="http://%s.oss-cn-hangzhou-zmf.aliyuncs.com/%s" % (bucket, target)
+
+if out.status == 200:
+    print("Uploaded to %s" % addr)
+else:
+    raise ValueError("Upload failed: %s" % addr)
diff --git a/examples/openai_chatcompletion_client.py b/examples/openai_chatcompletion_client.py
index 0b8e4b86ef5e..bbada3891bd1 100644
--- a/examples/openai_chatcompletion_client.py
+++ b/examples/openai_chatcompletion_client.py
@@ -32,6 +32,5 @@
     model=model,
 )
 
-print("Chat completion results:")
 print(chat_completion)
diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py
index 7a80c4ac49ab..58519f978d34 100644
--- a/examples/openai_completion_client.py
+++ b/examples/openai_completion_client.py
@@ -21,8 +21,7 @@
     echo=False,
     n=2,
     stream=stream,
-    logprobs=3
-)
+    logprobs=3)
 
 print("Completion results:")
 if stream:
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 6c169d266905..3f29cdc74a44 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -4,7 +4,7 @@
 from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union
 
 from vllm.config import CacheConfig, SchedulerConfig
-from vllm.core.block_manager import AllocStatus, BlockSpaceManager
+from vllm.core.block_manager import AllocStatus, BlockSpaceManager, BlockTable
 from vllm.core.policy import PolicyFactory
 from vllm.logger import init_logger
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
@@ -82,11 +82,31 @@ def __init__(
         self.running: Deque[SequenceGroup] = deque()
         # Sequence groups in the SWAPPED state.
         self.swapped: Deque[SequenceGroup] = deque()
+        self.prefix_cache: Dict[str, SequenceGroup] = {}
 
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
+        if seq_group.prefix_name in self.prefix_cache:
+            prefix_seq = self.prefix_cache[seq_group.prefix_name].get_seqs()[0]
+            if prefix_seq.is_finished():
+                seq_group.set_prefix_seq(prefix_seq)
+            else:
+                prefix = prefix_seq.prompt
+                prefix_token_ids = prefix_seq.data.prompt_token_ids
+                seqs = seq_group.get_seqs()
+                for seq in seqs:
+                    seq_group.seqs_dict[seq.seq_id] = Sequence(
+                        seq.seq_id, prefix + seq.prompt,
+                        prefix_token_ids + seq.data.prompt_token_ids,
+                        seq.block_size)
         self.waiting.append(seq_group)
 
+    def add_prefix_seq_groups(self,
+                              seq_group_map: Dict[str, SequenceGroup]) -> None:
+        self.prefix_cache.update(seq_group_map)
+        for seq_group in seq_group_map.values():
+            self.waiting.append(seq_group)
+
     def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
         if isinstance(request_id, str):
             request_id = (request_id, )
@@ -294,6 +314,11 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                 seq_id = seq.seq_id
                 seq_data[seq_id] = seq.data
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
+            self._reuse_prefix_cache(
+                seq_group=seq_group,
+                block_table=block_tables[seq_id],
+                blocks_to_copy=scheduler_outputs.blocks_to_copy
+                if scheduler_outputs.prompt_run else None)
 
             seq_group_metadata = SequenceGroupMetadata(
                 request_id=seq_group.request_id,
@@ -309,7 +334,9 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
         self.block_manager.fork(parent_seq, child_seq)
 
     def free_seq(self, seq: Sequence) -> None:
-        self.block_manager.free(seq)
+        # Don't free prefix sequences: their block_table is kept cached for reuse.
+        if not seq.block_table:
+            self.block_manager.free(seq)
 
     def free_finished_seq_groups(self) -> None:
         self.running = [
@@ -411,3 +438,29 @@ def _swap_out(
         blocks_to_swap_out.update(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             seq.status = SequenceStatus.SWAPPED
+
+    def _reuse_prefix_cache(self, seq_group: SequenceGroup,
+                            block_table: BlockTable,
+                            blocks_to_copy: Dict[int, List[int]]) -> None:
+        if seq_group.is_prefix:
+            assert len(seq_group.get_seqs()
+                       ) == 1, "prefix_seq_group only has one sequence."
+            seq = seq_group.get_seqs()[0]
+            seq.block_table = block_table  # cache the prefix block_table on the prefix seq
+        elif seq_group.prefix_seq:
+            # Reuse the prefix block_table from prefix_seq.
+            prefix_seq = seq_group.prefix_seq
+            need_copy_last_block = prefix_seq.get_prompt_len(
+            ) % prefix_seq.block_size > 0
+            prefix_block_num = len(
+                prefix_seq.block_table) - need_copy_last_block
+            block_table[:
+                        prefix_block_num] = prefix_seq.block_table[:
+                                                                   prefix_block_num]
+            if need_copy_last_block and blocks_to_copy is not None:
+                src_block, dst_block = prefix_seq.block_table[-1], block_table[
+                    prefix_block_num]
+                if src_block in blocks_to_copy:
+                    blocks_to_copy[src_block].append(dst_block)
+                else:
+                    blocks_to_copy[src_block] = [dst_block]
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 92f23ec29bfd..24bdb09636d4 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -370,6 +370,7 @@ async def add_request(
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
+        prefix_name: Optional[str] = None,
     ) -> AsyncStream:
         if self.log_requests:
             shortened_prompt = prompt
@@ -400,17 +401,18 @@ async def add_request(
             prompt=prompt,
             sampling_params=sampling_params,
             prompt_token_ids=prompt_token_ids,
-            arrival_time=arrival_time)
+            arrival_time=arrival_time,
+            prefix_name=prefix_name)
 
         return stream
 
     async def generate(
-        self,
-        prompt: Optional[str],
-        sampling_params: SamplingParams,
-        request_id: str,
-        prompt_token_ids: Optional[List[int]] = None
-    ) -> AsyncIterator[RequestOutput]:
+            self,
+            prompt: Optional[str],
+            sampling_params: SamplingParams,
+            request_id: str,
+            prompt_token_ids: Optional[List[int]] = None,
+            prefix_name: Optional[str] = None) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
 
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -424,6 +426,7 @@ async def generate(
             request_id: The unique id of the request.
             prompt_token_ids: The token IDs of the prompt. If None, we use
                 the tokenizer to convert the prompts to token IDs.
+            prefix_name: The name of a prefix already cached by the prompt cache.
 
         Yields:
             The output `RequestOutput` objects from the LLMEngine for the
@@ -438,7 +441,8 @@ async def generate(
             prompt,
             sampling_params,
             prompt_token_ids=prompt_token_ids,
-            arrival_time=arrival_time)
+            arrival_time=arrival_time,
+            prefix_name=prefix_name)
 
         async for request_output in stream:
             yield request_output
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 1920946a31d7..158fc8522392 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -318,6 +318,7 @@ def add_request(
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
+        prefix_name: Optional[str] = None,
     ) -> None:
         """Add a request to the engine's request pool.
 
@@ -334,12 +335,22 @@ def add_request(
                 use the tokenizer to convert the prompts to token IDs.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
+            prefix_name: The key into the prefix cache; the cached prefix KV
+                blocks are reused to speed up prompt computation.
""" if arrival_time is None: arrival_time = time.monotonic() if prompt_token_ids is None: assert prompt is not None prompt_token_ids = self.tokenizer.encode(prompt) + if prefix_name: + special_tokens = self.tokenizer.build_inputs_with_special_tokens( + [], []) + if special_tokens: + for idex, token_id in enumerate(prompt_token_ids): + if token_id not in special_tokens: + prompt_token_ids = prompt_token_ids[idex:] + break # Create the sequences. block_size = self.cache_config.block_size @@ -348,11 +359,47 @@ def add_request( # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time) + arrival_time, prefix_name) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) + def add_prefix_template(self, prefix_map: Dict[str, str]) -> None: + """ + prefix_map = {prefix_name: prefix_content or (prefix_content, prefix_token_ids)} + Generate the kv cache for the corresponding prefix_content, stored as promptcache. + Ensure that when a normal request is added, the required prefix_cache has already been created. + The current strategy is to concatenate the corresponding prefix_context if the promptcache has not been generated. + """ + greedy_sampling_params = SamplingParams(temperature=0.0, + use_beam_search=False, + max_tokens=1) + seq_group_map = {} + for prefix_name, con in prefix_map.items(): + if isinstance(con, tuple): + prefix_content, prefix_token_ids = con + else: + prefix_content = con + assert prefix_content is not None + prefix_token_ids = self.tokenizer.encode(prefix_content) + special_tokens = self.tokenizer.build_inputs_with_special_tokens( + [], []) + if special_tokens: + for idex in range(len(prefix_token_ids)): + if prefix_token_ids[-idex] not in special_tokens: + prefix_token_ids = prefix_token_ids[:len( + prefix_token_ids) - idex + 1] + break + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + seq = Sequence(seq_id, prefix_content, prefix_token_ids, + block_size) + # Create the sequence group. + seq_group_map[prefix_name] = SequenceGroup(None, [seq], + greedy_sampling_params, + None) + self.scheduler.add_prefix_seq_groups(seq_group_map) + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a request(s) with the given ID. @@ -602,6 +649,8 @@ def _process_model_outputs( # Create the outputs. request_outputs: List[RequestOutput] = [] for seq_group in scheduled_seq_groups: + if seq_group.is_prefix: + continue request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) for seq_group in scheduler_outputs.ignored_seq_groups: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0700298b03a3..ba86859d3342 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import List, Optional, Union, Dict from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -174,6 +174,10 @@ def _add_request( self.llm_engine.add_request(request_id, prompt, sampling_params, prompt_token_ids) + def add_prefix_template(self, prefix_map: Dict[str, str]) -> None: + self.llm_engine.add_prefix_template(prefix_map) + self._run_engine(True) + def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. 
         if use_tqdm:
diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py
index da615ecccf99..24c9071389f2 100644
--- a/vllm/model_executor/input_metadata.py
+++ b/vllm/model_executor/input_metadata.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, List
 
 import torch
 
@@ -22,6 +22,8 @@ def __init__(
         context_lens: Optional[torch.Tensor],
         block_tables: Optional[torch.Tensor],
         use_cuda_graph: bool,
+        prefix_len_list: Optional[List[int]] = None,
+        prefix_slot_mapping: Optional[torch.Tensor] = None,
     ) -> None:
         self.is_prompt = is_prompt
         self.max_context_len = max_context_len
@@ -29,6 +31,8 @@ def __init__(
         self.context_lens = context_lens
         self.block_tables = block_tables
         self.use_cuda_graph = use_cuda_graph
+        self.prefix_len_list = prefix_len_list
+        self.prefix_slot_mapping = prefix_slot_mapping
 
         # Set during the execution of the first attention op.
         # FIXME(woosuk): This is a hack.
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index 6482875d1c55..2bc2b6de3f49 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -1,11 +1,13 @@
 """Multi-head attention."""
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
-                                         LowerTriangularMaskWithTensorBias)
+                                         LowerTriangularMaskWithTensorBias,
+                                         BlockDiagonalCausalFromBottomRightMask
+                                         )
 
 from vllm._C import ops
 from vllm._C import cache_ops
@@ -99,6 +101,12 @@ def forward(
             )
 
         if input_metadata.is_prompt:
+            if sum(input_metadata.prefix_len_list) > 0:
+                key, value = _concat_prefix_kvcache(input_metadata, key, value,
+                                                    key_cache, value_cache,
+                                                    self.num_kv_heads,
+                                                    self.head_size)
+
             # Prompt run.
             if self.num_kv_heads != self.num_heads:
                 # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
@@ -120,17 +128,25 @@ def forward(
             # very attention layer of every iteration.
             # FIXME(woosuk): This is a hack.
             if input_metadata.attn_bias is None:
-                if self.alibi_slopes is None:
-                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                        [seq_len] * batch_size)
-                    if self.sliding_window is not None:
-                        attn_bias = attn_bias.make_local_attention(
-                            self.sliding_window)
-                    input_metadata.attn_bias = attn_bias
+                if input_metadata.prefix_len_list:
+                    prompt_lens = [seq_len] * batch_size
+                    kv_lens = [seq_len + max(input_metadata.prefix_len_list)
+                               ] * batch_size
+                    input_metadata.attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
+                        q_seqlen=prompt_lens, kv_seqlen=kv_lens)
+                    assert self.alibi_slopes is None, "alibi is not supported with the prefix cache yet"
                 else:
-                    input_metadata.attn_bias = _make_alibi_bias(
-                        self.alibi_slopes, self.num_kv_heads, batch_size,
-                        seq_len, query.dtype)
+                    if self.alibi_slopes is None:
+                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                            [seq_len] * batch_size)
+                        if self.sliding_window is not None:
+                            attn_bias = attn_bias.make_local_attention(
+                                self.sliding_window)
+                        input_metadata.attn_bias = attn_bias
+                    else:
+                        input_metadata.attn_bias = _make_alibi_bias(
+                            self.alibi_slopes, self.num_kv_heads, batch_size,
+                            seq_len, query.dtype)
 
             # TODO(woosuk): Too many view operations. Let's try to reduce them
             # in the future for code readability.
@@ -280,3 +296,64 @@ def _paged_attention(
         alibi_slopes,
     )
     return output
+
+
+def _concat_prefix_kvcache(
+        input_metadata: InputMetadata, key: torch.Tensor, value: torch.Tensor,
+        key_cache: torch.Tensor, value_cache: torch.Tensor, num_kv_heads: int,
+        head_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    prefix_key = torch.empty(sum(input_metadata.prefix_len_list),
+                             num_kv_heads,
+                             head_size,
+                             dtype=key.dtype,
+                             device=key.device)
+    prefix_value = torch.empty_like(prefix_key)
+    cache_ops.gather_cached_kv(
+        prefix_key,
+        prefix_value,
+        key_cache,
+        value_cache,
+        input_metadata.prefix_slot_mapping,
+    )
+    batch_size = len(input_metadata.prefix_len_list)
+    max_prompt_seq_len = key.shape[0] // batch_size
+    max_prefix_prompt_len = max_prompt_seq_len + max(
+        input_metadata.prefix_len_list)
+    if len(set(input_metadata.prefix_len_list)) == 1:
+        total_key = torch.concat(
+            (prefix_key.view(batch_size, -1, num_kv_heads, head_size),
+             key.view(batch_size, -1, num_kv_heads, head_size)),
+            dim=1).view(-1, num_kv_heads, head_size)
+        total_value = torch.concat(
+            (prefix_value.view(batch_size, -1, num_kv_heads, head_size),
+             value.view(batch_size, -1, num_kv_heads, head_size)),
+            dim=1).view(-1, num_kv_heads, head_size)
+    else:
+        total_key = torch.zeros(batch_size * max_prefix_prompt_len,
+                                num_kv_heads,
+                                head_size,
+                                dtype=key.dtype,
+                                device=key.device)
+        total_value = torch.zeros_like(total_key)
+        total_index = 0
+        prefix_index = 0
+        for i, prefix_len in enumerate(input_metadata.prefix_len_list):
+            total_key[total_index:total_index +
+                      prefix_len] = prefix_key[prefix_index:prefix_index +
+                                               prefix_len]
+            total_value[total_index:total_index +
+                        prefix_len] = prefix_value[prefix_index:prefix_index +
+                                                   prefix_len]
+            total_index += prefix_len
+            total_key[total_index:total_index +
+                      max_prompt_seq_len] = key[i *
+                                                max_prompt_seq_len:(i + 1) *
+                                                max_prompt_seq_len]
+            total_value[total_index:total_index +
+                        max_prompt_seq_len] = value[i *
+                                                    max_prompt_seq_len:(i +
+                                                                        1) *
+                                                    max_prompt_seq_len]
+            total_index += max_prefix_prompt_len - prefix_len
+            prefix_index += prefix_len
+    return total_key, total_value
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 7d36eeac0aa0..7c92e5385789 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -66,14 +66,16 @@ def __init__(
     ) -> None:
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids: List[int] = []
-        self.cumulative_logprob = 0.0
+        self.cumulative_logprob = 0.
+        self.prefix_len = 0
 
     def append_token_id(self, token_id: int, logprob: float) -> None:
         self.output_token_ids.append(token_id)
         self.cumulative_logprob += logprob
 
     def get_len(self) -> int:
-        return len(self.output_token_ids) + len(self.prompt_token_ids)
+        return len(self.output_token_ids) + len(
+            self.prompt_token_ids) + self.prefix_len
 
     def get_prompt_len(self) -> int:
         return len(self.prompt_token_ids)
@@ -132,6 +134,7 @@ def __init__(
         self.read_offset = 0
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
+        self.block_table = None
 
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
@@ -236,12 +239,24 @@ def __init__(
         seqs: List[Sequence],
         sampling_params: SamplingParams,
         arrival_time: float,
+        prefix_name: Optional[str] = None,
     ) -> None:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.prefix_seq = None
+        self.prefix_name = prefix_name
+        self.is_prefix = not request_id
+
+    def set_prefix_seq(self, prefix_seq: Sequence):
+        self.prefix_seq = prefix_seq
+        for seq in self.seqs_dict.values():
+            seq.logical_token_blocks = []
+            seq._append_tokens_to_blocks(prefix_seq.data.prompt_token_ids +
+                                         seq.data.prompt_token_ids)
+            seq.data.prefix_len = self.prefix_seq.data.get_len() - 1
 
     @property
     def prompt(self) -> str:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index be2803089f51..024711be6950 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -81,6 +81,10 @@ def _prepare_prompt(
         slot_mapping: List[List[int]] = []
         prompt_lens: List[int] = []
 
+        # For prompts with a prefix, the input positions must be offset by prefix_len.
+        prefix_slot_mapping = []
+        prefix_len_list = []
+
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -88,6 +92,8 @@ def _prepare_prompt(
             seq_id = seq_ids[0]
 
             seq_data = seq_group_metadata.seq_data[seq_id]
+            prefix_len = seq_data.prefix_len
+            prefix_len_list.append(prefix_len)
             prompt_tokens = seq_data.get_token_ids()
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
@@ -95,7 +101,8 @@ def _prepare_prompt(
             input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.append(list(range(prompt_len)))
+            input_positions.append(
+                list(range(prefix_len, prompt_len + prefix_len)))
 
             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
@@ -113,16 +120,20 @@ def _prepare_prompt(
             # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
             start_idx = 0
             if self.sliding_window is not None:
-                start_idx = max(0, prompt_len - self.sliding_window)
-            for i in range(prompt_len):
-                if i < start_idx:
+                start_idx = max(0,
+                                prompt_len + prefix_len - self.sliding_window)
+
+            for i in range(prompt_len + prefix_len):
+                if i >= prefix_len and i < start_idx:
                     slot_mapping[-1].append(_PAD_SLOT_ID)
                     continue
-
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping[-1].append(slot)
+                if i < prefix_len:
+                    prefix_slot_mapping.append(slot)
+                else:
+                    slot_mapping[-1].append(slot)
 
         max_prompt_len = max(prompt_lens)
         input_tokens = _make_tensor_with_pad(input_tokens,
@@ -137,6 +148,9 @@ def _prepare_prompt(
                                              max_prompt_len,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long)
+        prefix_slot_mapping = torch.tensor(prefix_slot_mapping,
+                                           dtype=torch.int,
+                                           device="cuda")
 
         input_metadata = InputMetadata(
             is_prompt=True,
@@ -145,6 +159,8 @@ def _prepare_prompt(
             context_lens=None,
             block_tables=None,
             use_cuda_graph=False,
+            prefix_len_list=prefix_len_list,
+            prefix_slot_mapping=prefix_slot_mapping,
         )
 
         return input_tokens, input_positions, input_metadata, prompt_lens
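A minimal usage sketch (not part of the patch) of the prefix-cache API this diff introduces, using the offline LLM entry point. The model name, prefix name and text, prompt, and request id are illustrative assumptions; the request is submitted through the engine-level add_request because LLM.generate itself does not accept prefix_name in this patch.

# Illustrative only: model, prefix, prompt, and request id are made up.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

# Register the shared prefix once. add_prefix_template runs a one-token greedy
# pass so the prefix KV blocks are computed and cached before real requests.
llm.add_prefix_template({
    "system": "You are a helpful assistant. Answer concisely.\n",
})

# Requests that name the prefix reuse its cached KV blocks instead of
# recomputing attention over the shared prefix tokens.
params = SamplingParams(temperature=0.0, max_tokens=32)
llm.llm_engine.add_request(
    request_id="0",
    prompt="What is the capital of France?",
    sampling_params=params,
    prefix_name="system",
)
for output in llm._run_engine(use_tqdm=True):
    print(output.outputs[0].text)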