7 changes: 0 additions & 7 deletions charts/model-engine/templates/gateway_deployment.yaml
@@ -49,13 +49,6 @@ spec:
port: 5000
periodSeconds: 2
failureThreshold: 30
livenessProbe:
httpGet:
path: /healthz
port: 5000
initialDelaySeconds: 5
periodSeconds: 2
failureThreshold: 10
command:
- dumb-init
- --
48 changes: 35 additions & 13 deletions model-engine/model_engine_server/api/app.py
@@ -5,7 +5,7 @@
from pathlib import Path

import pytz
from fastapi import FastAPI, Request, Response
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from model_engine_server.api.batch_jobs_v1 import batch_job_router_v1
@@ -21,6 +21,7 @@
from model_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1
from model_engine_server.api.tasks_v1 import inference_task_router_v1
from model_engine_server.api.triggers_v1 import trigger_router_v1
from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
from model_engine_server.core.loggers import (
LoggerTagKey,
LoggerTagManager,
@@ -32,12 +32,34 @@

logger = make_logger(logger_name())

# Allows us to make the Uvicorn worker concurrency in model_engine_server/api/worker.py very high
MAX_CONCURRENCY = 500

concurrency_limiter = MultiprocessingConcurrencyLimiter(
concurrency=MAX_CONCURRENCY, fail_on_concurrency_limit=True
)

healthcheck_routes = ["/healthcheck", "/healthz", "/readyz"]


class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
try:
LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4()))
return await call_next(request)
# we intentionally exclude healthcheck routes from the concurrency limiter
if request.url.path in healthcheck_routes:
return await call_next(request)
with concurrency_limiter:
Contributor:
What do people think about trying this out with just a specific route at first? Looking at the breakdown in the past week, get_/v1/async-tasks/_task_id is the most common route by far.

I'm not sure if it makes sense to do a global limit, as we know some routes take more time than others.
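A rough sketch of what a per-route variant could look like, assuming a mapping from path prefixes to dedicated limiters (the prefixes and limits below are purely illustrative and not part of this PR):

    from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter

    # Illustrative: give the hot async-tasks route its own, larger budget.
    route_limiters = {
        "/v1/async-tasks": MultiprocessingConcurrencyLimiter(
            concurrency=300, fail_on_concurrency_limit=True
        ),
    }
    default_limiter = MultiprocessingConcurrencyLimiter(
        concurrency=100, fail_on_concurrency_limit=True
    )

    def limiter_for(path: str) -> MultiprocessingConcurrencyLimiter:
        # Prefix match so /v1/async-tasks/{task_id} shares one limiter.
        for prefix, limiter in route_limiters.items():
            if path.startswith(prefix):
                return limiter
        return default_limiter

The middleware's dispatch would then enter limiter_for(request.url.path) instead of the single global limiter.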

Contributor:

https://kubernetes.io/docs/tasks/configure-pod-container/configure-liveness-readiness-startup-probes/#define-a-liveness-http-request

Any code greater than or equal to 200 and less than 400 indicates success. Any other code indicates failure.

@squeakymouse From the docs, it looks like if our readiness probe route returns a 429, it would cause the pod to be marked as unready, and should result in 503s from istio again.

It is a little odd since our experimentation doesn't seem to show that...

Contributor Author:

Hmm I think it shows up as the context deadline exceeded (Client.Timeout exceeded while awaiting headers) errors? 🤔

Does this mean I should try to exclude the healthcheck route from the concurrency limiting?

return await call_next(request)
except HTTPException as e:
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
return JSONResponse(
status_code=e.status_code,
content={
"error": e.detail,
"timestamp": timestamp,
},
)
except Exception as e:
tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
@@ -49,14 +72,12 @@ async def dispatch(self, request: Request, call_next):
}
logger.error("Unhandled exception: %s", structured_log)
return JSONResponse(
{
"status_code": 500,
"content": {
"error": "Internal error occurred. Our team has been notified.",
"timestamp": timestamp,
"request_id": request_id,
},
}
status_code=500,
content={
"error": "Internal error occurred. Our team has been notified.",
"timestamp": timestamp,
"request_id": request_id,
},
)


@@ -91,9 +112,10 @@ def load_redis():
get_or_create_aioredis_pool()


@app.get("/healthcheck")
@app.get("/healthz")
@app.get("/readyz")
def healthcheck() -> Response:
"""Returns 200 if the app is healthy."""
return Response(status_code=200)


for endpoint in healthcheck_routes:
app.get(endpoint)(healthcheck)
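The loop works because app.get(endpoint) returns a decorator, so calling it with healthcheck registers the same handler on each path. A quick sanity check of the new routes, sketched with FastAPI's TestClient (assumes the app from this diff is importable; illustrative, not part of the PR):

    from fastapi.testclient import TestClient
    from model_engine_server.api.app import app

    client = TestClient(app)

    # All three health endpoints share one handler and skip the concurrency
    # limiter, so they should return 200 even when the service is saturated.
    for path in ["/healthcheck", "/healthz", "/readyz"]:
        assert client.get(path).status_code == 200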
5 changes: 3 additions & 2 deletions model-engine/model_engine_server/api/worker.py
@@ -1,8 +1,9 @@
from uvicorn.workers import UvicornWorker

# Gunicorn returns 503 instead of 429 when concurrency exceeds the limit, before adding rate limiting just increase the concurrency
# Gunicorn returns 503 instead of 429 when concurrency exceeds the limit
# We'll autoscale at target concurrency of a much lower number (around 50), and this just makes sure we don't 503 with bursty traffic
CONCURRENCY_LIMIT = 1000
# We set this very high since model_engine_server/api/app.py sets a lower per-pod concurrency at which we start returning 429s
CONCURRENCY_LIMIT = 10000


class LaunchWorker(UvicornWorker):
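The LaunchWorker body is collapsed in this view; as an assumption about how the limit is typically passed through to Uvicorn (not shown in this PR), the wiring would look roughly like:

    from uvicorn.workers import UvicornWorker

    CONCURRENCY_LIMIT = 10000

    class LaunchWorker(UvicornWorker):
        # Assumed wiring: Uvicorn rejects requests with 503 once limit_concurrency
        # is reached, so it sits far above the app-level limit that returns 429 first.
        CONFIG_KWARGS = {"limit_concurrency": CONCURRENCY_LIMIT}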
36 changes: 36 additions & 0 deletions model-engine/model_engine_server/common/concurrency_limiter.py
@@ -0,0 +1,36 @@
from multiprocessing import BoundedSemaphore
Contributor:
I'm wondering if we should be using an async semaphore rather than a multiprocessing one, since FastAPI is using async.

This Stack Overflow comment seems to suggest that we should be using the corresponding semaphore:

use the correct type of semaphore for the form of concurrency being used
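For context, a rough sketch of what the suggested asyncio-based variant might look like (illustrative only, not part of this PR; the middleware would enter it with async with instead of with):

    import asyncio
    from typing import Optional

    from fastapi import HTTPException

    class AsyncConcurrencyLimiter:
        # Illustrative asyncio counterpart of MultiprocessingConcurrencyLimiter.
        def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool):
            self.concurrency = concurrency
            self.blocking = not fail_on_concurrency_limit
            self.semaphore = (
                asyncio.BoundedSemaphore(concurrency) if concurrency is not None else None
            )

        async def __aenter__(self):
            if self.semaphore is None:
                return
            if not self.blocking and self.semaphore.locked():
                # All slots taken: fail fast like the multiprocessing version.
                raise HTTPException(status_code=429, detail="Too many requests")
            await self.semaphore.acquire()

        async def __aexit__(self, exc_type, exc, tb):
            # Only reached if __aenter__ acquired a slot; a raised 429 skips the body.
            if self.semaphore is not None:
                self.semaphore.release()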

from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType
from typing import Optional

from fastapi import HTTPException
from model_engine_server.core.loggers import logger_name, make_logger

logger = make_logger(logger_name())


class MultiprocessingConcurrencyLimiter:
def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool):
self.concurrency = concurrency
if concurrency is not None:
if concurrency < 1:
raise ValueError("Concurrency should be at least 1")
self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency)
self.blocking = (
not fail_on_concurrency_limit
) # we want to block if we want to queue up requests
else:
self.semaphore = None
self.blocking = False # Unused

def __enter__(self):
logger.debug("Entering concurrency limiter semaphore")
if self.semaphore and not self.semaphore.acquire(block=self.blocking):
logger.warning(f"Too many requests (max {self.concurrency}), returning 429")
raise HTTPException(status_code=429, detail="Too many requests")
# Just raises an HTTPException.
# __exit__ should not run; otherwise the release() doesn't have an acquire()

def __exit__(self, type, value, traceback):
logger.debug("Exiting concurrency limiter semaphore")
if self.semaphore:
self.semaphore.release()
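A small usage sketch of the extracted limiter outside the middleware (the values here are illustrative):

    from model_engine_server.common.concurrency_limiter import (
        MultiprocessingConcurrencyLimiter,
    )

    # Illustrative: allow 2 concurrent callers and return 429 rather than queueing.
    limiter = MultiprocessingConcurrencyLimiter(concurrency=2, fail_on_concurrency_limit=True)

    def handle_request():
        # Raises fastapi.HTTPException(429) when both slots are already in use.
        with limiter:
            ...  # do the work while holding a slot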
@@ -3,11 +3,9 @@
import os
import subprocess
from functools import lru_cache
from multiprocessing import BoundedSemaphore
from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType
from typing import Optional

from fastapi import Depends, FastAPI, HTTPException
from fastapi import Depends, FastAPI
from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.inference.forwarding.forwarding import (
@@ -21,33 +21,6 @@
app = FastAPI()


class MultiprocessingConcurrencyLimiter:
def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool):
if concurrency is not None:
if concurrency < 1:
raise ValueError("Concurrency should be at least 1")
self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency)
self.blocking = (
not fail_on_concurrency_limit
) # we want to block if we want to queue up requests
else:
self.semaphore = None
self.blocking = False # Unused

def __enter__(self):
logger.debug("Entering concurrency limiter semaphore")
if self.semaphore and not self.semaphore.acquire(block=self.blocking):
logger.warning("Too many requests, returning 429")
raise HTTPException(status_code=429, detail="Too many requests")
# Just raises an HTTPException.
# __exit__ should not run; otherwise the release() doesn't have an acquire()

def __exit__(self, type, value, traceback):
logger.debug("Exiting concurrency limiter semaphore")
if self.semaphore:
self.semaphore.release()


@app.get("/healthz")
@app.get("/readyz")
def healthcheck():
@@ -1,10 +1,8 @@
import traceback
from functools import wraps
from multiprocessing import BoundedSemaphore
from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType
from typing import Optional

from fastapi import BackgroundTasks, FastAPI, HTTPException, Response, status
from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.inference.common import (
Expand All @@ -25,33 +23,6 @@
logger = make_logger(logger_name())


class MultiprocessingConcurrencyLimiter:
def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool):
if concurrency is not None:
if concurrency < 1:
raise ValueError("Concurrency should be at least 1")
self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency)
self.blocking = (
not fail_on_concurrency_limit
) # we want to block if we want to queue up requests
else:
self.semaphore = None
self.blocking = False # Unused

def __enter__(self):
logger.debug("Entering concurrency limiter semaphore")
if self.semaphore and not self.semaphore.acquire(block=self.blocking):
logger.warning("Too many requests, returning 429")
raise HTTPException(status_code=429, detail="Too many requests")
# Just raises an HTTPException.
# __exit__ should not run; otherwise the release() doesn't have an acquire()

def __exit__(self, type, value, traceback):
logger.debug("Exiting concurrency limiter semaphore")
if self.semaphore:
self.semaphore.release()


def with_concurrency_limit(concurrency_limiter: MultiprocessingConcurrencyLimiter):
def _inner(flask_func):
@wraps(flask_func)