feat(router): Dynamic batch sizing #210
Closed
Motivation
Currently, to avoid OOM you must set a "worst case" max batch size based on the desired max sequence length. This means that (a) throughput is unnecessarily limited when there are many shorter sequences, and (b) you have to be pretty conservative about the max context length offered.
These changes introduce a maximum batch "weight" parameter which in the flash attention case corresponds to a maximum total number of tokens in the batch. The idea is that this is roughly proportional to the memory requirement.
[…] max_new_tokens values.
If max_batch_weight is not set, it just infers this from the max_batch_size and max_total_tokens args. In this case it should behave roughly the same as it does now, so it could hopefully be a "non breaking" change for existing configurations.
It turns out to be simpler to configure for a particular model/GPU. The precise values for max_batch_size and max_sequence_length no longer matter much; they can both be set quite high. You just need to determine one number (the max weight / total tokens), which is easy to do with minimal experimentation.
We have been using this successfully for a while now, and it means we can support a much higher throughput / volume of users with the same hardware while offering larger context lengths. For example, we have a deployment of GPT-NeoX 20B on one 80GB A100 with the max batch size set to 256 and the max sequence length (max_total_tokens) set to 8192. The actual batch size flexes automatically as needed. Our max_batch_weight setting for this is 10k.
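To make the weight idea concrete, here is a minimal sketch of how a per-request weight and the fallback cap could be computed. The Entry struct, its fields, and the weight / effective_max_batch_weight names are invented for illustration; they are not the router's actual code.

```rust
/// Hypothetical request entry; field and function names are illustrative only.
struct Entry {
    input_length: u32,
    max_new_tokens: u32,
}

impl Entry {
    /// In the flash-attention case an entry's "weight" is simply its
    /// worst-case token count: prompt tokens plus tokens it may generate,
    /// which is roughly proportional to its memory footprint.
    fn weight(&self) -> u32 {
        self.input_length + self.max_new_tokens
    }
}

/// If no explicit weight cap is configured, fall back to the worst-case
/// product of the existing settings, which reproduces the old behaviour of a
/// fixed max batch size at the full sequence length.
fn effective_max_batch_weight(
    max_batch_weight: Option<u32>,
    max_batch_size: u32,
    max_total_tokens: u32,
) -> u32 {
    max_batch_weight.unwrap_or(max_batch_size * max_total_tokens)
}

fn main() {
    // Settings similar to the GPT-NeoX 20B example above.
    let cap = effective_max_batch_weight(Some(10_000), 256, 8192);

    // A short chat-style request vs. a long-context request.
    let short = Entry { input_length: 40, max_new_tokens: 10 };
    let long = Entry { input_length: 7000, max_new_tokens: 1000 };

    // Under a token budget, many short requests can share a batch, while a
    // single long request uses most of the budget on its own.
    println!("short requests per batch: {}", cap / short.weight()); // 200
    println!("long requests per batch:  {}", cap / long.weight()); // 1
}
```

Under the same 10k budget the batch thus flexes between a couple of hundred short requests and a handful of long ones, instead of being pinned to the worst case.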
Details/caveats
I've only included the implementation for the flash attention case so far. The additions to generalize to the regular attention case aren't very big (we run non-flash-attention models with this too), but I thought this was probably complicated enough to start with. It will of course need to support the general case before actually being included.
[…] next_batch should return immediately before getting into the more complex logic.
Since the next_batch function now takes the current entries map instead of a min and max, the tests in queue.rs would need updating, so I just removed them for now.
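To make the last two points concrete, here is a rough sketch of a queue whose next_batch takes the map of in-flight entries and returns early when there is nothing to do. The Queue, Entry, and Batch types, the BTreeMap-keyed storage, and the exact admission rule are all assumptions for illustration rather than the code in this PR or in queue.rs.

```rust
use std::collections::BTreeMap;

// Stand-in types; the router's real Entry and Batch structs are richer.
struct Entry {
    input_length: u32,
    max_new_tokens: u32,
}

struct Batch {
    ids: Vec<u64>,
}

struct Queue {
    pending: BTreeMap<u64, Entry>,
}

impl Queue {
    // Instead of explicit (min_size, max_size) bounds, this sketch receives
    // the map of entries currently being processed so it can count their
    // weight against the shared token budget.
    fn next_batch(
        &mut self,
        current_entries: &BTreeMap<u64, Entry>,
        max_batch_weight: u32,
    ) -> Option<Batch> {
        // Fast path: nothing queued, return immediately before the more
        // involved bookkeeping below.
        if self.pending.is_empty() {
            return None;
        }

        // Weight already committed to the in-flight entries.
        let mut total: u32 = current_entries
            .values()
            .map(|e| e.input_length + e.max_new_tokens)
            .sum();

        // Admit queued entries in key (arrival) order while the budget holds.
        let mut ids = Vec::new();
        for id in self.pending.keys().copied().collect::<Vec<u64>>() {
            let e = &self.pending[&id];
            let weight = e.input_length + e.max_new_tokens;
            if total + weight > max_batch_weight {
                break;
            }
            total += weight;
            self.pending.remove(&id);
            ids.push(id);
        }

        if ids.is_empty() {
            None
        } else {
            Some(Batch { ids })
        }
    }
}

fn main() {
    let mut queue = Queue { pending: BTreeMap::new() };
    queue.pending.insert(1, Entry { input_length: 128, max_new_tokens: 128 });
    queue.pending.insert(2, Entry { input_length: 6000, max_new_tokens: 2000 });
    let in_flight: BTreeMap<u64, Entry> = BTreeMap::new();
    if let Some(batch) = queue.next_batch(&in_flight, 10_000) {
        println!("next batch contains {} request ids", batch.ids.len());
    }
}
```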