# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Default health check payloads for the sglang backend.

Two payload shapes are defined: a flat request for the regular generate
endpoint and a wrapped request for prefill workers in disaggregated mode.
"""

from dynamo.health_check import HealthCheckPayload


class SglangHealthCheckPayload(HealthCheckPayload):
    """
    Health check payload carrying the sglang default request.

    Only the default request body is defined here; all other behavior
    (including environment override support) comes from HealthCheckPayload.
    """

    def __init__(self):
        """
        Build the minimal one-token request and delegate to the base class.

        The layout follows what DecodeWorkerHandler receives from the frontend.
        """
        # Stop after a single generated token so the probe stays cheap.
        stop_conditions = {
            "max_tokens": 1,
            "ignore_eos": False,
        }
        sampling_options = {
            "temperature": 0.0,
            "top_p": 1.0,
            "top_k": -1,
        }
        payload = {
            "token_ids": [1],  # one prompt token keeps processing minimal
            "stop_conditions": stop_conditions,
            "sampling_options": sampling_options,
            "eos_token_ids": [],
            "annotations": [],
        }
        # Assigned before the base initializer runs; presumably the base class
        # reads this attribute during __init__ -- confirm against HealthCheckPayload.
        self.default_payload = payload
        super().__init__()


class SglangPrefillHealthCheckPayload(HealthCheckPayload):
    """
    Health check payload for sglang prefill workers in disaggregated mode.

    The prefill handler reads the request from a wrapper that has a
    'request' entry and a 'sampling_params' entry.
    """

    def __init__(self):
        """
        Build the wrapped one-token prefill request and delegate to the base class.
        """
        wrapped_request = {
            "token_ids": [1],  # one prompt token keeps processing minimal
        }
        # Limit generation to a single new token so the probe stays cheap.
        sampling_params = {
            "max_new_tokens": 1,
            "temperature": 0.0,
            "top_p": 1.0,
            "top_k": -1,
            "ignore_eos": False,
        }
        # Assigned before the base initializer runs; presumably the base class
        # reads this attribute during __init__ -- confirm against HealthCheckPayload.
        self.default_payload = {
            "request": wrapped_request,
            "sampling_params": sampling_params,
        }
        super().__init__()
] diff --git a/components/backends/sglang/src/dynamo/sglang/request_handlers/decode_handler.py b/components/backends/sglang/src/dynamo/sglang/request_handlers/decode_handler.py index 634af32a8d..4089c6e90b 100644 --- a/components/backends/sglang/src/dynamo/sglang/request_handlers/decode_handler.py +++ b/components/backends/sglang/src/dynamo/sglang/request_handlers/decode_handler.py @@ -53,7 +53,7 @@ def _build_sampling_params(self, request: dict) -> dict: sampling_params["ignore_eos"] = request["stop_conditions"]["ignore_eos"] return sampling_params - async def generate(self, request: str): + async def generate(self, request: dict): sampling_params = self._build_sampling_params(request) if self.serving_mode == DisaggregationMode.DECODE: @@ -62,7 +62,7 @@ async def generate(self, request: str): DisaggPreprocessedRequest( request=request, sampling_params=sampling_params, - ).model_dump_json() + ).model_dump() ) bootstrap_info = None diff --git a/components/backends/sglang/src/dynamo/sglang/request_handlers/prefill_handler.py b/components/backends/sglang/src/dynamo/sglang/request_handlers/prefill_handler.py index 21e34f578b..b28cf1b208 100644 --- a/components/backends/sglang/src/dynamo/sglang/request_handlers/prefill_handler.py +++ b/components/backends/sglang/src/dynamo/sglang/request_handlers/prefill_handler.py @@ -6,7 +6,6 @@ import random import socket -import msgspec import sglang as sgl from sglang.srt.utils import get_ip @@ -46,8 +45,7 @@ def _get_bootstrap_info(self): return bootstrap_host, bootstrap_port - async def generate(self, request: str): - req = msgspec.json.decode(request, type=dict) + async def generate(self, request: dict): bootstrap_room = self._generate_bootstrap_room() bootstrap_info = { @@ -59,8 +57,8 @@ async def generate(self, request: str): yield bootstrap_info results = await self.engine.async_generate( - input_ids=req["request"]["token_ids"], - sampling_params=req["sampling_params"], + input_ids=request["request"]["token_ids"], + 
sampling_params=request["sampling_params"], stream=True, bootstrap_host=self.bootstrap_host, bootstrap_port=self.bootstrap_port,