feat(server): improve max tokens calculation #246
Conversation
@OlivierDehaene I haven't had a chance yet to respond to #226, will do soon. But I'm not sure about the change in this PR. Specifically, I don't think that only taking into account the number of tokens generated until a first request completes makes sense w.r.t. the batch inclusion decision. Yes, the total token count will shrink at that point, but it can then continue to grow arbitrarily large without any new requests being added, depending on how many tokens the other requests in the batch generate. Sorry if I've misunderstood something, which is very possible!
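To make the concern concrete, here is a minimal sketch (hypothetical request sizes, not code from this PR or repo) of how the total token count dips when the first request completes and then keeps growing afterwards:

```python
# Hypothetical illustration: in-flight token totals for a batch of two requests,
# assuming one token is generated per step and a request leaves the batch as soon
# as it has produced max_new_tokens.
requests = {
    "A": {"input_len": 20, "max_new_tokens": 10},    # completes early
    "B": {"input_len": 20, "max_new_tokens": 1000},  # keeps generating
}

peak = 0
last_step = max(r["max_new_tokens"] for r in requests.values())
for step in range(1, last_step + 1):
    total = sum(
        r["input_len"] + step
        for r in requests.values()
        if step <= r["max_new_tokens"]  # finished requests have left the batch
    )
    peak = max(peak, total)

# At step 10, when A completes, the total is 60 and then drops to 31 at step 11,
# but B alone later grows to 20 + 1000 = 1020 without any new request being added.
print(peak)  # 1020
```

So a bound computed only up to the first completion (60 here) says nothing about the eventual peak (1020).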
You're correct. We discussed internally and decided to postpone for now. The core issue is that we're running in a low free RAM mode and trying to squeeze things a bit too much. We're going to find a different route where we have more free space.
@Narsil that's part of the reason for the complexity of my original impl, see the equivalent analysis here. You can use a cruder upper bound for the analysis, which does make the logic simpler (as is done in the current main branch impl), but you pay for it in reduced utilization. My PR only covers the flash attention (non-padded) case; give me a bit of time today and I will push the changes to make it generic (I have each impl in different branches right now but the delta is relatively minor). It would be good to talk it through with you guys if you are interested / have time!
If I'm not mistaken, your original implementation does allow request re-ordering, correct? IMO this can be OK in some circumstances, but is probably not ideal in others.
In general, if a system is overloaded, it's not going to behave nicely. Here we were attempting to work in a low free RAM environment, but the truth is that we are low on RAM, so we're capped in the amount of parallelism we can use. Fixing this core issue in our hardware is the only viable solution IMO (since it's doable).
@Narsil actually the reordering is an additional optimization, I was referring more to the batch admission logic. The logic I linked to does a "projection" for each possible next request to add to the batch, to determine whether its inclusion would result in the configured max batch "weight" (e.g. total token count) being exceeded immediately or at any point in the future. I.e. it segments the future time-steps based on the worst-case number of tokens that each request will generate, and ensures that the batch shape at the end of each of these segments fits. The reordering is in addition to this: if a particular next request doesn't fit according to the above logic, it will walk back up the queue to find other requests that do fit, but will only consider requests that arrived within a configurable time threshold of the request at the front of the queue (e.g. 1 second). This hopefully avoids the concerns that you described w.r.t. fairness and tail latencies.
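A rough sketch of what such a projection-based admission check could look like, assuming (as in the flash attention case) that the batch "weight" is just the total token count; the names `Entry`, `fits` and `max_batch_weight` are illustrative, not the linked implementation:

```python
from dataclasses import dataclass

@dataclass
class Entry:
    current_len: int  # tokens already held (input + generated so far)
    remaining: int    # worst case: max_new_tokens still left to generate

def fits(batch: list[Entry], candidate: Entry, max_batch_weight: int) -> bool:
    """Check the batch weight at each future worst-case completion point."""
    entries = batch + [candidate]
    # Segment the future time-steps at each entry's worst-case completion point;
    # within a segment the weight only grows, so its peak is at the segment boundary.
    for steps in sorted({e.remaining for e in entries}):
        weight = sum(
            e.current_len + steps    # every entry still alive has grown by `steps` tokens
            for e in entries
            if e.remaining >= steps  # entries that finish earlier have left the batch
        )
        if weight > max_batch_weight:
            return False
    return True
```

Because only the segment boundaries need to be checked, an admission attempt costs roughly O(n log n) in the number of batch entries.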
The goal of this is to maximize the throughput for a given GPU. Beyond that, you can always scale up/out via a bigger GPU or multiple replicas.
Are you referring to using larger GPUs / more of them?
I should also mention that the "weight" is calculated differently for the non-flash case, since the memory requirement is quadratic in sequence length... so I use batch_size * seq_len^2.
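As a sketch of the two weight definitions being contrasted here (a hypothetical helper, not the repository's code):

```python
def batch_weight(seq_lens: list[int], padded: bool) -> int:
    """Batch 'weight' for the admission bound, given per-request sequence lengths."""
    if padded:
        # Non-flash (padded) attention: memory is quadratic in sequence length,
        # so weight ~ batch_size * seq_len^2 with the padded-to (longest) length.
        return len(seq_lens) * max(seq_lens) ** 2
    # Flash attention (unpadded): the total token count is what matters.
    return sum(seq_lens)
```

For example, `batch_weight([128, 512], padded=True)` is 2 * 512² = 524288, while `batch_weight([128, 512], padded=False)` is just 640.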
Yes, but in chat apps you always run with a very large max_new_tokens. Then if you jump in front of another request, it could mean a latency increase of tens of seconds for the request being jumped.
Pretty much. Just a generally better balance of RAM, FLOPS, $ :) (Here we're short on RAM for the requests since the model is already taking up too much, leaving not enough room to feed the compute efficiently, basically.)
But whatever resulting delay there is would have happened anyhow if the larger request had arrived 1 sec later. The queue-skipping can easily be disabled / doesn't need to be included. The motivation was for cases where there's a large request that needs some existing large requests in the batch to complete to make room, but in the meantime shorter requests could be processed without actually affecting how long the larger one has to wait. I did actually think about making that part more sophisticated, to allow queue-jumping only in cases where the time-step at which the skipped request would be introduced isn't affected (or is perhaps additionally delayed by at most some maximum time like 1 sec). But the arrival-time threshold was simpler for a first impl.
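A minimal sketch of that bounded queue-skipping, reusing the hypothetical `Entry`/`fits` names from the admission sketch above; `skip_window` stands in for the configurable arrival-time threshold (e.g. 1 second):

```python
from dataclasses import dataclass

@dataclass
class QueuedRequest:
    arrival_time: float  # e.g. time.monotonic() at enqueue time
    entry: Entry         # same Entry as in the admission sketch above

def next_admittable(queue: list[QueuedRequest], batch: list[Entry],
                    max_batch_weight: int, skip_window: float = 1.0):
    """Return (and remove) the first queued request that fits the batch, looking
    past the head only at requests that arrived within skip_window seconds of it."""
    if not queue:
        return None
    head = queue[0]
    for i, req in enumerate(queue):
        if req.arrival_time - head.arrival_time > skip_window:
            break  # too far behind the head: never let it jump the queue
        if fits(batch, req.entry, max_batch_weight):
            return queue.pop(i)
    return None
```

Setting `skip_window` to 0 effectively disables the skipping, which matches the "can easily be disabled" point above.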