diff --git a/tensorrt_llm/llmapi/disagg_utils.py b/tensorrt_llm/llmapi/disagg_utils.py index 049a49f04c4..42cff0b0601 100644 --- a/tensorrt_llm/llmapi/disagg_utils.py +++ b/tensorrt_llm/llmapi/disagg_utils.py @@ -58,6 +58,7 @@ class MetadataServerConfig(): hostname: str = "localhost" port: int = 2379 health_check_timeout: float = 5.0 + refresh_interval: float = 10.0 def parse_disagg_config_file(yaml_config_file: str): diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index c5300ae4bff..62a64a1cd73 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -35,20 +35,20 @@ class OpenAIDisaggServer: def __init__(self, - ctx_servers: List[str] = None, - gen_servers: List[str] = None, + ctx_servers: List[str], + gen_servers: List[str], req_timeout_secs: int = 180, server_start_timeout_secs: int = 180, ctx_router_config: Optional[RouterConfig] = None, gen_router_config: Optional[RouterConfig] = None, conditional_disagg_config: Optional[ConditionalDisaggConfig] = None, - metadata_server_cfg: MetadataServerConfig = None): + metadata_server_cfg: Optional[MetadataServerConfig] = None): self.ctx_servers = ctx_servers self.gen_servers = gen_servers self.metadata_server = create_metadata_server(metadata_server_cfg) - self.ctx_router = create_router(ctx_router_config, ctx_servers, self.metadata_server) - self.gen_router = create_router(gen_router_config, gen_servers, self.metadata_server) + self.ctx_router = create_router(ctx_router_config, ctx_servers, metadata_server_cfg, self.metadata_server) + self.gen_router = create_router(gen_router_config, gen_servers, metadata_server_cfg, self.metadata_server) self.conditional_disagg_config = conditional_disagg_config @@ -76,8 +76,8 @@ async def lifespan(app: FastAPI): if self.metadata_server: logger.info("Starting server monitoring via metadata service") - await self.ctx_router.start_server_monitoring() - await self.gen_router.start_server_monitoring() + await self.ctx_router.start_server_monitoring(metadata_server_cfg.refresh_interval) + await self.gen_router.start_server_monitoring(metadata_server_cfg.refresh_interval) yield diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py index 3dabbfaef8d..233101a0e06 100644 --- a/tensorrt_llm/serve/router.py +++ b/tensorrt_llm/serve/router.py @@ -8,7 +8,8 @@ from tensorrt_llm.bindings.internal.batch_manager import (BlockKey, BlockKeyHasher) -from tensorrt_llm.llmapi.disagg_utils import RouterConfig, ServerRole +from tensorrt_llm.llmapi.disagg_utils import (MetadataServerConfig, + RouterConfig, ServerRole) from tensorrt_llm.logger import logger from tensorrt_llm.serve.metadata_server import JsonDictionary from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, @@ -144,18 +145,18 @@ def num_active_requests(self): class Router(ABC): - def __init__(self, - server_role: ServerRole, - servers: List[str] = None, - metadata_server: JsonDictionary = None): + def __init__(self, server_role: ServerRole, servers: List[str], + metadata_server_cfg: Optional[MetadataServerConfig], + metadata_server: Optional[JsonDictionary]): self._servers = servers or [] self._metadata_server = metadata_server self._server_role = server_role self._lock = asyncio.Lock() self._monitor_task = None self._session = None - self._health_check_timeout = 5.0 # Default timeout in seconds + self._health_check_timeout = metadata_server_cfg.health_check_timeout if metadata_server_cfg else None + @abstractmethod def _on_servers_updated(self, old_servers, new_servers): """Called when the server list changes. Override in subclasses to handle index resets. Args: @@ -171,12 +172,10 @@ async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]: async def finish_request(self, request: OpenAIRequest): pass - async def start_server_monitoring(self, poll_interval: int = 10): + async def start_server_monitoring(self, poll_interval: float = 10.0): """Start monitoring servers update from metadata service""" if not self._metadata_server: - logger.info( - "No metadata server configured, skipping server monitoring") - return + raise RuntimeError("Metadata server is not initialized") # Create a session for health checks if it doesn't exist if not self._session: @@ -212,55 +211,57 @@ async def close_session(self): logger.error(f"Error closing session: {e}") self._session = None - async def _monitor_servers(self, poll_interval: int = 10): + async def _monitor_servers(self, poll_interval: float = 10.0): while True: try: - if self._metadata_server: - # Get servers from metadata - server_key_map = await self.fetch_live_servers() - - # Check health and get live servers - live_servers = await self.check_servers_health( - server_key_map) - - # Filter by server role if needed - role_specific_servers = self._filter_servers_by_role( - live_servers, server_key_map) - - # Use filtered servers if available - final_servers = role_specific_servers if role_specific_servers else [] - - # Update server list - async with self._lock: - if final_servers != self._servers: - num_old_servers = len(self._servers) - old_servers = self._servers.copy() - self._servers = final_servers - num_new_servers = len(self._servers) - - # Call handler for server list changes - self._on_servers_updated(old_servers, self._servers) - - logger.info( - f"Updated {self._server_role} server list: {num_old_servers} -> {num_new_servers} servers" - ) - if logger.level == "debug" and self._servers: - for server in self._servers: - logger.debug(f" - {server}") - else: - logger.debug( - f"No change in {self._server_role} server list: {len(self._servers)} servers" - ) + # Get servers from metadata + server_key_map = await self.fetch_live_servers() + + # Check health and get live servers + live_servers = await self.check_servers_health(server_key_map) + + # Filter by server role if needed + role_specific_servers = self._filter_servers_by_role( + live_servers, server_key_map) + + # Use filtered servers if available + final_servers = role_specific_servers + + assert final_servers, f"No {self._server_role} servers available" + + # Update server list + async with self._lock: + if final_servers != self._servers: + old_servers = self._servers.copy() + self._servers = final_servers + + # Call handler for server list changes + self._on_servers_updated(old_servers, self._servers) + + # Log removed servers + for server in old_servers: + if server not in final_servers: + logger.info(f"Server {server} is removed") + + # Log added servers + for server in final_servers: + if server not in old_servers: + logger.info(f"Server {server} is added") + else: + logger.debug( + f"No change in {self._server_role} server list: {len(self._servers)} servers" + ) except Exception as e: logger.error(f"Error in server monitoring: {e}") + raise # Wait before next poll await asyncio.sleep(poll_interval) def _filter_servers_by_role(self, servers, server_key_map): """Filter servers by role (context or generation)""" - if not self._metadata_server or not servers: - return [] + if not servers: + raise RuntimeError("No servers available") filtered_servers = [] # Invert to get {url: key} for lookup @@ -271,29 +272,25 @@ def _filter_servers_by_role(self, servers, server_key_map): if key: server_metadata = self._metadata_server.get(key) if server_metadata: - # Use either server_type or server_role field - server_type = server_metadata.get('server_type', '').lower() - if not server_type: - server_type = server_metadata.get('server_role', - '').lower() - - # Extract port for visibility - parts = server_url.split(':') - if len(parts) >= 3: - parts[2] - - # Check if server type matches our role - if (self._server_role == ServerRole.CONTEXT and server_type == 'context') or \ - (self._server_role == ServerRole.GENERATION and server_type == 'generation'): + server_type = self._get_server_type(server_metadata) + + if self._is_matching_role(server_type): filtered_servers.append(server_url) return filtered_servers + def _get_server_type(self, server_metadata: dict) -> str: + return (server_metadata.get('server_type') + or server_metadata.get('server_role') or '').lower() + + def _is_matching_role(self, server_type: str) -> bool: + return (self._server_role == ServerRole.CONTEXT and server_type == 'context') or \ + (self._server_role == ServerRole.GENERATION and server_type == 'generation') + async def fetch_live_servers(self) -> Dict[str, str]: """Fetch all servers from metadata service and return {key: url} mapping""" if not self._metadata_server: - # Only use static servers if no metadata server - return {server: "" for server in self._servers} + raise RuntimeError("Metadata server is not initialized") # If metadata server is available, ignore static server list entirely server_key_map = {} @@ -303,37 +300,23 @@ async def fetch_live_servers(self) -> Dict[str, str]: logger.debug(f"Found {len(all_keys)} keys in metadata server") # Filter keys that start with 'trtllm/' and extract server metadata - matching_keys = 0 for key in all_keys: if key.startswith('trtllm/'): - matching_keys += 1 server_metadata = self._metadata_server.get(key) if server_metadata and isinstance( server_metadata, dict) and 'url' in server_metadata: server_key_map[key] = server_metadata['url'] - # Check if metadata includes health check timeout - if 'health_check_timeout' in server_metadata: - try: - self._health_check_timeout = float( - server_metadata['health_check_timeout']) - logger.debug( - f"Using health check timeout: {self._health_check_timeout}s" - ) - except (ValueError, TypeError): - logger.warning( - f"Invalid health_check_timeout value: {server_metadata['health_check_timeout']}" - ) - if server_key_map: - logger.info( + logger.debug( f"Using {len(server_key_map)} servers from metadata service" ) else: - logger.warning("No servers found in metadata service") + raise ValueError("No servers found in metadata service") except Exception as e: logger.error(f"Error fetching servers from metadata service: {e}") + raise return server_key_map @@ -343,66 +326,58 @@ async def check_servers_health(self, live_servers = [] dead_servers = [] - try: - # Check health of each server - for key, server_url in server_key_map.items(): - # First attempt - no printing errors - is_healthy = await self._check_server_health(server_url, - silent=True) + # Check health of each server + for key, server_url in server_key_map.items(): + try: + is_healthy = await self._check_server_health(server_url) # If first attempt failed, try again before declaring server dead if not is_healthy: # Second attempt - will print errors if it fails - is_healthy = await self._check_server_health(server_url, - silent=False) - - if not is_healthy: - # Only now add to dead servers - dead_servers.append((key, server_url)) - logger.warning( - f"Server {server_url} is not healthy after retry - removing" - ) - else: - live_servers.append(server_url) + is_healthy = await self._check_server_health(server_url) + + if not is_healthy: + # Only now add to dead servers + dead_servers.append((key, server_url)) else: live_servers.append(server_url) + except Exception as e: + logger.error( + f"Error checking health for server {server_url} (key: {key}): {e}" + ) + dead_servers.append((key, server_url)) - # Remove dead servers from etcd - for key, dead_server in dead_servers: - try: - logger.info( - f"Removing dead server {dead_server} from metadata server" - ) - self._metadata_server.remove(key) - except Exception as e: - logger.error( - f"Error removing dead server from metadata service: {e}" - ) - - except Exception as e: - logger.error(f"Error checking server health: {e}") + # Remove dead servers from etcd + for key, dead_server in dead_servers: + try: + logger.info( + f"Removing dead server {dead_server} from metadata server") + self._metadata_server.remove(key) + except Exception as e: + logger.error( + f"Error removing dead server from metadata service: {e}") + raise - return live_servers if live_servers else self._servers + return live_servers - async def _check_server_health(self, server_url, silent=False) -> bool: + async def _check_server_health(self, server_url) -> bool: """Check if a server is healthy by querying its health endpoint""" if not self._session: self._session = aiohttp.ClientSession() + assert self._health_check_timeout is not None, "health_check_timeout is not set" try: async with self._session.get( f"{server_url}/health", timeout=self._health_check_timeout) as response: if response.status != 200: - if not silent: - logger.warning( - f"Server {server_url} is not healthy (status: {response.status})" - ) + logger.warning( + f"Server {server_url} is not healthy (status: {response.status})" + ) return False return True except Exception as e: - if not silent: - logger.warning(f"Server {server_url} is not reachable: {e}") + logger.warning(f"Server {server_url} is not reachable: {e}") return False @@ -411,9 +386,11 @@ class RoundRobinRouter(Router): def __init__(self, server_role: ServerRole, servers: List[str] = None, + metadata_server_cfg: MetadataServerConfig = None, metadata_server: JsonDictionary = None, **kwargs): - super().__init__(server_role, servers, metadata_server) + super().__init__(server_role, servers, metadata_server_cfg, + metadata_server) self._server_idx = 0 def _on_servers_updated(self, old_servers, new_servers): @@ -452,10 +429,12 @@ class LoadBalancingRouter(Router): def __init__(self, server_role: ServerRole, servers: List[str] = None, + metadata_server_cfg: MetadataServerConfig = None, metadata_server: JsonDictionary = None, use_tokens: bool = False, **kwargs): - super().__init__(server_role, servers, metadata_server) + super().__init__(server_role, servers, metadata_server_cfg, + metadata_server) # Load map between servers and their number of tokens processed self._server_state = {} self._server_load_heap = [] @@ -535,12 +514,14 @@ class KvCacheAwareRouter(Router): def __init__(self, server_role: ServerRole = None, servers: list[str] = None, + metadata_server_cfg: MetadataServerConfig = None, metadata_server: JsonDictionary = None, use_tokens: bool = False, max_batch_size: int = 64, tokens_per_block: int = 32, **kwargs): - super().__init__(server_role, servers, metadata_server) + super().__init__(server_role, servers, metadata_server_cfg, + metadata_server) self._lock = asyncio.Lock() # Load map between servers and their number of tokens processed @@ -626,10 +607,15 @@ async def finish_request(self, await self._server_state[server].decrement_load(request, session=session) + def _on_servers_updated(self, old_servers, new_servers): + raise NotImplementedError( + "KvCacheAwareRouter does not support server updates") + def create_router(router_config: Optional[RouterConfig], - servers: List[str], - metadata_server: JsonDictionary = None) -> Router: + servers: Optional[List[str]], + metadata_server_cfg: Optional[MetadataServerConfig] = None, + metadata_server: Optional[JsonDictionary] = None) -> Router: """ Factory function to create different types of router instances. @@ -663,5 +649,5 @@ def create_router(router_config: Optional[RouterConfig], f"Supported types are: {list(router_map.keys())}") # Pass server_role as the first argument - return router_class(router_config.server_role, servers, metadata_server, - **router_config.args) + return router_class(router_config.server_role, servers, metadata_server_cfg, + metadata_server, **router_config.args) diff --git a/tests/unittest/disaggregated/test_router.py b/tests/unittest/disaggregated/test_router.py index 308c326fc20..eb4beffa4a0 100644 --- a/tests/unittest/disaggregated/test_router.py +++ b/tests/unittest/disaggregated/test_router.py @@ -21,12 +21,10 @@ def __init__(self): self.lock = threading.Lock() def get(self, key): - print("***** get *****", key) with self.lock: return self.servers.get(key) def put(self, key, value): - print("***** put *****", key, value) with self.lock: self.servers[key] = value return True @@ -39,7 +37,6 @@ def remove(self, key): return False def add_server(self, key, url): - print("***** add_server *****", key, url) with self.lock: self.servers[key] = url return True @@ -363,8 +360,8 @@ async def test_fetch_live_servers_context(mock_metadata_server, router_class): metadata_server=mock_metadata_server) # Initial check - should be no servers - servers = await router.fetch_live_servers() - assert len(servers) == 0, "Should have no servers initially" + with pytest.raises(ValueError): + servers = await router.fetch_live_servers() # Add a server server_key = "trtllm/server1"