9 changes: 7 additions & 2 deletions router/src/queue.rs
@@ -162,7 +162,7 @@ impl State {

let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0;
let mut max_decode_steps: u32 = u32::MAX;

// Pop entries starting from the front of the queue
while let Some((id, mut entry)) = self.entries.pop_front() {
@@ -182,7 +182,10 @@
prefill_tokens += entry.request.input_length;
}

decode_tokens += entry.request.stopping_parameters.max_new_tokens;
max_decode_steps =
max_decode_steps.min(entry.request.stopping_parameters.max_new_tokens);

let decode_tokens = max_decode_steps * (batch_requests.len() + 1) as u32;

if (prefill_tokens + decode_tokens) > token_budget {
// Entry is over budget
@@ -236,6 +239,8 @@ impl State {
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);

let decode_tokens = size * max_decode_steps;

let batch = Batch {
id: self.next_batch_id,
requests: batch_requests,
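For readers skimming the diff: the router previously summed every queued request's max_new_tokens into decode_tokens, which over-reserves the token budget, because the batch only decodes until its shortest request finishes and gets filtered out. The change keeps a running minimum, max_decode_steps, and reserves max_decode_steps per request instead. A minimal sketch of the new accounting, written in Python for illustration (the function and variable names below are hypothetical, not part of the PR):

def fits_token_budget(input_lengths, max_new_tokens_per_request, token_budget):
    # Prefill cost: every input token is processed once.
    prefill_tokens = sum(input_lengths)
    # The batch only decodes until its shortest request finishes, so the
    # decode reservation is min(max_new_tokens) * batch_size rather than
    # sum(max_new_tokens).
    max_decode_steps = min(max_new_tokens_per_request)
    decode_tokens = max_decode_steps * len(max_new_tokens_per_request)
    return prefill_tokens + decode_tokens <= token_budget

# Example: three requests of 5 input tokens each with max_new_tokens = [10, 40, 100].
# Old accounting: 15 + (10 + 40 + 100) = 165 tokens reserved.
# New accounting: 15 + 10 * 3 = 45 tokens, so more requests fit under the same budget.
assert fits_token_budget([5, 5, 5], [10, 40, 100], token_budget=50)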
51 changes: 38 additions & 13 deletions router/src/validation.rs
@@ -380,50 +380,75 @@ pub enum ValidationError {
}

#[cfg(test)]
mod tests{
mod tests {
use super::*;
use std::io::Write;

#[tokio::test]
async fn test_validation_max_new_tokens(){
async fn test_validation_max_new_tokens() {
let tokenizer = None;
let max_best_of = 2;
let max_stop_sequence = 3;
let max_input_length = 4;
let max_total_tokens = 5;
let workers = 1;
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_input_length,
max_total_tokens,
);

let max_new_tokens = 10;
match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
match validation
.validate_input("Hello".to_string(), None, max_new_tokens)
.await
{
Err(ValidationError::MaxNewTokens(1, 10)) => (),
_ => panic!("Unexpected not max new tokens")
_ => panic!("Unexpected not max new tokens"),
}
}

async fn get_tokenizer() -> Tokenizer{
if !std::path::Path::new("tokenizer.json").exists(){
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
let mut file = std::fs::File::create("tokenizer.json").unwrap();
async fn get_tokenizer() -> Tokenizer {
if !std::path::Path::new("tokenizer.json").exists() {
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
.await
.unwrap()
.bytes()
.await
.unwrap();
let mut file = std::fs::File::create("tokenizer.json").unwrap();
file.write_all(&content).unwrap();
}
Tokenizer::from_file("tokenizer.json").unwrap()
}

#[tokio::test]
async fn test_validation_input_length(){
async fn test_validation_input_length() {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
let max_input_length = 4;
let max_total_tokens = 5;
let workers = 1;
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_input_length,
max_total_tokens,
);

let max_new_tokens = 10;
match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
match validation
.validate_input("Hello".to_string(), None, max_new_tokens)
.await
{
Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
_ => panic!("Unexpected not max new tokens")
_ => panic!("Unexpected not max new tokens"),
}
}
}
59 changes: 44 additions & 15 deletions server/text_generation_server/models/causal_lm.py
@@ -48,6 +48,8 @@ class CausalLMBatch(Batch):

# Maximum number of tokens this batch will grow to
max_tokens: int
# Maximum number of decode steps before at least one request finishes
max_decode_steps: int

# Past metadata
keys_head_dim_last: bool = True
@@ -77,7 +79,7 @@ def from_pb(
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
max_decode_steps = None
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
@@ -89,7 +91,15 @@
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens

# Maximum number of decode steps before one request finishes
if max_decode_steps is None:
max_decode_steps = stopping_criteria.max_new_tokens
else:
max_decode_steps = min(
max_decode_steps, stopping_criteria.max_new_tokens
)

padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
@@ -118,7 +128,10 @@ def from_pb(
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

max_tokens = len(inputs) * max_input_length + max_decode_tokens
# Since we are sure that at least one request will be dropped in max_decode_steps,
# we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps
# before getting filtered and decreasing in size
max_tokens = len(inputs) * (max_input_length + max_decode_steps)

return cls(
batch_id=pb.id,
Expand All @@ -137,6 +150,7 @@ def from_pb(
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
max_decode_steps=max_decode_steps,
)

@tracer.start_as_current_span("filter")
@@ -159,8 +173,8 @@ def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatc
next_token_choosers = []
stopping_criterias = []

total_remaining_decode_tokens = 0
new_padding_right_offset = 0
max_decode_steps = None

for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
@@ -178,13 +192,17 @@ def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatc
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
remaining_decode_tokens = (

# Remaining decode steps for this request
remaining_decode = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
if max_decode_steps is None:
max_decode_steps = remaining_decode
else:
max_decode_steps = min(max_decode_steps, remaining_decode)

new_padding_right_offset = max(new_padding_right_offset, remaining_decode)

# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
@@ -217,7 +235,10 @@ def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatc
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values

max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
# Since we are sure that at least one request will be dropped in max_decode_steps,
# we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps
# before getting filtered and decreasing in size
max_tokens = len(requests) * (max_input_length + max_decode_steps)

self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
@@ -232,6 +253,7 @@ def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatc
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
self.max_decode_steps = max_decode_steps

return self

@@ -256,14 +278,15 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0

# Batch tensors
input_ids = None
attention_mask = None
position_ids = None
past_key_values = []

max_decode_steps = None

# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
@@ -341,10 +364,11 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
layer[k] = t.view(len(batch), -1, *t.shape[-2:])

start_index = end_index
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)

if max_decode_steps is None:
max_decode_steps = batch.max_decode_steps
else:
max_decode_steps = min(max_decode_steps, batch.max_decode_steps)

first_past_kvs = batches[0].past_key_values
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
@@ -417,6 +441,8 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":

past_key_values.append([padded_past_keys, padded_past_values])

max_tokens = len(requests) * (max_input_length + max_decode_steps)

return cls(
batch_id=batches[0].batch_id,
requests=requests,
@@ -435,6 +461,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
max_decode_steps=max_decode_steps,
)

def __len__(self):
@@ -636,6 +663,8 @@ def generate_token(
batch.attention_mask[:, -batch.padding_right_offset] = 1
# Decrease right offset
batch.padding_right_offset -= 1
# Decrease max_decode_steps
batch.max_decode_steps -= 1

# Update position_ids
batch.position_ids = batch.position_ids[:, -1:] + 1
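The Python batching code applies the same bound: since at least one request is guaranteed to finish within max_decode_steps decode steps, after which filter shrinks the batch, the KV cache can only grow to batch_size * (max_input_length + max_decode_steps) tokens. That is the max_tokens value now computed in from_pb, filter and concatenate, and generate_token decrements max_decode_steps once per decode step. A worked sketch of the bound with made-up numbers (the helper name below is hypothetical):

def kv_cache_bound(padded_input_length, remaining_decode_budgets):
    # remaining_decode_budgets[i] = max_new_tokens - current_tokens for request i.
    batch_size = len(remaining_decode_budgets)
    # At least one request finishes after min(remaining_decode_budgets) steps,
    # after which filter() drops it and the batch shrinks.
    max_decode_steps = min(remaining_decode_budgets)
    return batch_size * (padded_input_length + max_decode_steps)

# Example: 4 requests padded to 12 input tokens, remaining budgets [3, 7, 20, 64].
# The cache is bounded by 4 * (12 + 3) = 60 tokens until the shortest request is
# filtered out; the previous sum-based bound would have been 4 * 12 + 94 = 142.
print(kv_cache_bound(12, [3, 7, 20, 64]))  # 60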