Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions components/backends/sglang/src/dynamo/sglang/health_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
sglang-specific health check configuration.

This module defines the default health check payload for sglang backends.
"""

from dynamo.health_check import HealthCheckPayload


class SglangHealthCheckPayload(HealthCheckPayload):
    """
    Default health check request for SGLang workers.

    Supplies the SGLang-flavoured default payload; environment override
    support is inherited from the base class.
    """

    def __init__(self):
        """
        Build the minimal SGLang health check request and hand off to the base class.

        The payload is shaped the way DecodeWorkerHandler expects requests
        coming from the frontend.
        """
        # Stop after a single generated token so the probe stays cheap.
        stop_conditions = {
            "max_tokens": 1,
            "ignore_eos": False,
        }
        sampling_options = {
            "temperature": 0.0,
            "top_p": 1.0,
            "top_k": -1,
        }

        # NOTE(review): default_payload is assigned before the base initializer
        # runs, presumably because the base class reads it there -- confirm
        # against HealthCheckPayload before reordering.
        self.default_payload = {
            "token_ids": [1],  # one input token keeps processing minimal
            "stop_conditions": stop_conditions,
            "sampling_options": sampling_options,
            "eos_token_ids": [],
            "annotations": [],
        }
        super().__init__()


class SglangPrefillHealthCheckPayload(HealthCheckPayload):
    """
    Default health check request for SGLang prefill workers (disaggregated mode).

    The prefill handler consumes a wrapped structure carrying a 'request'
    entry and a 'sampling_params' entry, so the payload is built in that shape.
    """

    def __init__(self):
        """
        Build the wrapped prefill health check request and hand off to the base class.
        """
        # A single input token keeps the prefill pass as small as possible.
        wrapped_request = {
            "token_ids": [1],
        }
        # Generate only one new token.
        sampling_params = {
            "max_new_tokens": 1,
            "temperature": 0.0,
            "top_p": 1.0,
            "top_k": -1,
            "ignore_eos": False,
        }

        # NOTE(review): default_payload is assigned before the base initializer
        # runs, presumably because the base class reads it there -- confirm
        # against HealthCheckPayload before reordering.
        self.default_payload = {
            "request": wrapped_request,
            "sampling_params": sampling_params,
        }
        super().__init__()
10 changes: 10 additions & 0 deletions components/backends/sglang/src/dynamo/sglang/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.health_check import (
SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload,
)
from dynamo.sglang.publisher import setup_sgl_metrics
from dynamo.sglang.register import register_llm_with_runtime_config
from dynamo.sglang.request_handlers import DecodeWorkerHandler, PrefillWorkerHandler
Expand Down Expand Up @@ -112,6 +116,8 @@ async def register_model():
ready_event.set()
logging.info("Model registration succeeded; processing queued requests")

health_check_payload = SglangHealthCheckPayload().to_dict()

try:
# Start endpoint immediately and register model concurrently
# Requests queue until ready_event is set
Expand All @@ -120,6 +126,7 @@ async def register_model():
handler.generate,
graceful_shutdown=True,
metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
),
register_model(),
)
Expand Down Expand Up @@ -150,11 +157,14 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):

handler = PrefillWorkerHandler(component, engine, config)

health_check_payload = SglangPrefillHealthCheckPayload().to_dict()

tasks = [
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
health_check_payload=health_check_payload,
)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _build_sampling_params(self, request: dict) -> dict:
sampling_params["ignore_eos"] = request["stop_conditions"]["ignore_eos"]
return sampling_params

async def generate(self, request: str):
async def generate(self, request: dict):
sampling_params = self._build_sampling_params(request)

if self.serving_mode == DisaggregationMode.DECODE:
Expand All @@ -62,7 +62,7 @@ async def generate(self, request: str):
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump_json()
).model_dump()
)

bootstrap_info = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import random
import socket

import msgspec
import sglang as sgl
from sglang.srt.utils import get_ip

Expand Down Expand Up @@ -46,8 +45,7 @@ def _get_bootstrap_info(self):

return bootstrap_host, bootstrap_port

async def generate(self, request: str):
req = msgspec.json.decode(request, type=dict)
async def generate(self, request: dict):
bootstrap_room = self._generate_bootstrap_room()

bootstrap_info = {
Expand All @@ -59,8 +57,8 @@ async def generate(self, request: str):
yield bootstrap_info

results = await self.engine.async_generate(
input_ids=req["request"]["token_ids"],
sampling_params=req["sampling_params"],
input_ids=request["request"]["token_ids"],
sampling_params=request["sampling_params"],
stream=True,
bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port,
Expand Down
Loading