From fedd473b88460d0c0356a7838c330c93e94a6117 Mon Sep 17 00:00:00 2001
From: Matt Kotzbauer
Date: Fri, 2 May 2025 20:20:54 -0400
Subject: [PATCH 1/5] Initial attempt at gRPC metadata-based room assignment
 to support multiple quorums on a single LighthouseServer
 (https://github.com/pytorch/torchft/issues/173)

---
 Cargo.toml                   |  2 +
 src/bin/lighthouse.rs        | 14 ++++--
 src/lib.rs                   | 98 +++++++++++++++++++++++++++++-------
 src/lighthouse.rs            |  2 +-
 src/router.rs                | 88 ++++++++++++++++++++++++++++++++
 torchft/multi_quorum_test.py | 46 +++++++++++++++++
 6 files changed, 228 insertions(+), 22 deletions(-)
 create mode 100644 src/router.rs
 create mode 100644 torchft/multi_quorum_test.py

diff --git a/Cargo.toml b/Cargo.toml
index 0c6ae6e9..6ff04507 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,6 +9,7 @@ askama = "0.12.1"
 atty = "0.2.14"
 axum = "0.7.7"
 chrono = "0.4.40"
+dashmap = "6.1"
 fern = {version = "0.7.1", features = ["colored"]}
 gethostname = "0.5.0"
 log = "0.4.22"
@@ -21,6 +22,7 @@ slog-stdlog = "4.1.1"
 stderrlog = "0.6.0"
 structopt = "0.3.26"
 tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] }
+tokio-stream = "0.1"
 tonic = "0.12.2"
 
 [build-dependencies]
diff --git a/src/bin/lighthouse.rs b/src/bin/lighthouse.rs
index dbce458b..e9f944f5 100644
--- a/src/bin/lighthouse.rs
+++ b/src/bin/lighthouse.rs
@@ -4,8 +4,11 @@
 // This source code is licensed under the BSD-style license found in the
 // LICENSE file in the root directory of this source tree.
 
+use std::net::SocketAddr;
 use structopt::StructOpt;
-use torchft::lighthouse::{Lighthouse, LighthouseOpt};
+use torchft::lighthouse::LighthouseOpt;
+use torchft::router::Router;
+use torchftpb::lighthouse_service_server::LighthouseServiceServer;
 
 #[tokio::main(flavor = "multi_thread", worker_threads = 4)]
 async fn main() {
@@ -17,7 +20,10 @@ async fn main() {
         .unwrap();
 
     let opt = LighthouseOpt::from_args();
-    let lighthouse = Lighthouse::new(opt).await.unwrap();
-
-    lighthouse.run().await.unwrap();
+    let router = Router::new(opt.clone());
+    Server::builder()
+        .add_service(LighthouseServiceServer::new(router))
+        .serve(opt.bind.parse::<SocketAddr>().unwrap())
+        .await
+        .unwrap();
 }
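Note: this series (and its tests) binds to addresses like "[::]:0". As a standalone illustration, not part of the patch: port 0 asks the OS for any free port, and the bracketed IPv6 wildcard parses with the same `parse::<SocketAddr>()` call used above.

    use std::net::SocketAddr;

    fn main() {
        // "[::]:0" = IPv6 wildcard host with an OS-assigned (ephemeral) port.
        let addr: SocketAddr = "[::]:0".parse().expect("valid socket address");
        assert_eq!(addr.port(), 0); // the real port is only known after bind()
        println!("parsed {addr}");
    }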
diff --git a/src/lib.rs b/src/lib.rs
index 5ef1bcfa..6fb0f0dc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -8,8 +8,11 @@ pub mod lighthouse;
 pub mod manager;
 mod net;
 mod retry;
+mod router;
 mod timeout;
 
+pub use crate::router::Router;
+
 use anyhow::Result;
 use atty::Stream;
 use core::time::Duration;
@@ -21,6 +24,7 @@ use std::thread::available_parallelism;
 use structopt::StructOpt;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
+use tokio_stream::wrappers::TcpListenerStream;
 use tonic::transport::Channel;
 use tonic::Status;
 
@@ -33,7 +37,9 @@ pub mod torchftpb {
 }
 
 use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
+use crate::torchftpb::lighthouse_service_server::LighthouseServiceServer;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
+use crate::torchftpb::LighthouseHeartbeatRequest;
 use crate::torchftpb::{
     CheckpointMetadataRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ShouldCommitRequest,
 };
@@ -336,9 +342,12 @@ fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
 }
 
 async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
-    let lighthouse = lighthouse::Lighthouse::new(opt).await?;
+    let router = Router::new(opt.clone());
 
-    lighthouse.run().await?;
+    tonic::transport::Server::builder()
+        .add_service(LighthouseServiceServer::new(router))
+        .serve(opt.bind.parse::<SocketAddr>()?)
+        .await?;
 
     Ok(())
 }
@@ -476,13 +485,19 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
 struct LighthouseClient {
     client: LighthouseServiceClient<Channel>,
     runtime: Runtime,
+    room_id: Option<String>,
 }
 
 #[pymethods]
 impl LighthouseClient {
-    #[pyo3(signature = (addr, connect_timeout))]
+    #[pyo3(signature = (addr, connect_timeout, room_id = None))]
     #[new]
-    fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
+    fn new(
+        py: Python<'_>,
+        addr: String,
+        connect_timeout: Duration,
+        room_id: Option<String>,
+    ) -> PyResult<Self> {
         py.allow_threads(move || {
             let runtime = tokio::runtime::Builder::new_multi_thread()
                 .worker_threads(num_threads())
@@ -495,6 +510,7 @@ impl LighthouseClient {
             Ok(Self {
                 client: client,
                 runtime: runtime,
+                room_id: room_id,
             })
         })
     }
@@ -549,6 +565,8 @@ impl LighthouseClient {
             }),
         });
 
+        let mut request = self.add_room_header(request);
+
         // This timeout is processed on the server side so we also enable
         // keep alives to detect server health.
         request.set_timeout(timeout);
@@ -562,6 +580,41 @@ impl LighthouseClient {
         });
         Ok(convert_quorum(py, &quorum?)?)
     }
+
+    /// Send a single heartbeat to the lighthouse.
+    ///
+    /// Args:
+    ///     replica_id (str): The replica_id you registered with.
+    ///     timeout (timedelta, optional): Per-RPC deadline. Default = 5 s.
+    #[pyo3(signature = (replica_id, timeout = Duration::from_secs(5)))]
+    fn heartbeat(
+        &self,
+        py: Python<'_>,
+        replica_id: String,
+        timeout: Duration,
+    ) -> Result<(), StatusError> {
+        py.allow_threads(move || {
+            let mut req = tonic::Request::new(LighthouseHeartbeatRequest { replica_id });
+            let mut req = self.add_room_header(req);
+            req.set_timeout(timeout);
+            self.runtime.block_on(self.client.clone().heartbeat(req))?;
+            Ok(())
+        })
+    }
+}
+
+impl LighthouseClient {
+    /// Attach `"room-id"` header if `self.room_id` is Some(_)
+    fn add_room_header<T>(&self, mut req: tonic::Request<T>) -> tonic::Request<T> {
+        if let Some(ref id) = self.room_id {
+            use tonic::metadata::MetadataValue;
+            req.metadata_mut().insert(
+                crate::router::ROOM_ID_HEADER,
+                MetadataValue::try_from(id.as_str()).expect("room-id ascii"),
+            );
+        }
+        req
+    }
+}
 
 /// LighthouseServer is a GRPC server for the lighthouse service.
 ///
 /// It is used to coordinate the ManagerServer for each replica group.
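Note: the `expect("room-id ascii")` in `add_room_header` can only trip for non-ASCII ids, because tonic's ASCII metadata values validate their bytes on construction. A tiny standalone check (assumes only the tonic crate; the id strings are made up):

    use tonic::metadata::MetadataValue;

    fn main() {
        // Visible-ASCII ids are accepted as gRPC metadata values...
        assert!(MetadataValue::try_from("jobA").is_ok());
        // ...non-ASCII ids are refused, which is what the expect() guards against.
        assert!(MetadataValue::try_from("Ünicode").is_err());
    }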
 /// heartbeat_timeout_ms (int): The timeout for heartbeats.
 #[pyclass]
 struct LighthouseServer {
-    lighthouse: Arc<Lighthouse>,
+    bind: String,
     handle: JoinHandle<Result<()>>,
     _runtime: Runtime,
 }
@@ -607,19 +660,30 @@ impl LighthouseServer {
             .enable_all()
             .build()?;
 
-        let lighthouse = rt
-            .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
-                bind: bind,
-                min_replicas: min_replicas,
-                join_timeout_ms: join_timeout_ms,
-                quorum_tick_ms: quorum_tick_ms,
-                heartbeat_timeout_ms: heartbeat_timeout_ms,
-            }))
-            .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+        let opt = lighthouse::LighthouseOpt {
+            bind: bind.clone(),
+            min_replicas,
+            join_timeout_ms,
+            quorum_tick_ms,
+            heartbeat_timeout_ms,
+        };
+
+        let listener = rt.block_on(tokio::net::TcpListener::bind(&bind))?;
+        let bound_sock = listener.local_addr()?;
+        let bound = format!("http://{}", bound_sock);
+        let incoming = TcpListenerStream::new(listener);
+
+        let handle = rt.spawn(async move {
+            tonic::transport::Server::builder()
+                .add_service(LighthouseServiceServer::new(Router::new(opt.clone())))
+                .serve_with_incoming(incoming)
+                .await
+                .map_err(|e: tonic::transport::Error| anyhow::anyhow!(e))
+        });
 
         Ok(Self {
-            handle: rt.spawn(lighthouse.clone().run()),
-            lighthouse: lighthouse,
+            bind: bound,
+            handle,
             _runtime: rt,
         })
     })
@@ -630,7 +694,7 @@ impl LighthouseServer {
     /// Returns:
     ///     str: The address of the lighthouse server.
     fn address(&self) -> PyResult<String> {
-        Ok(self.lighthouse.address().to_string())
+        Ok(self.bind.clone())
     }
 
     /// shutdown shuts down the lighthouse server.
diff --git a/src/lighthouse.rs b/src/lighthouse.rs
index fdf03aff..5c71a858 100644
--- a/src/lighthouse.rs
+++ b/src/lighthouse.rs
@@ -83,7 +83,7 @@ impl ChangeLogger {
     }
 }
 
-#[derive(StructOpt, Debug)]
+#[derive(StructOpt, Debug, Clone)]
 #[structopt()]
 pub struct LighthouseOpt {
     // bind is the address to bind the server to.
diff --git a/src/router.rs b/src/router.rs
new file mode 100644
index 00000000..fd521577
--- /dev/null
+++ b/src/router.rs
@@ -0,0 +1,88 @@
+use std::sync::Arc;
+
+use dashmap::{mapref::entry::Entry, DashMap};
+use tonic::{Request, Response, Status};
+
+use crate::{
+    lighthouse::{Lighthouse, LighthouseOpt},
+    torchftpb::{
+        lighthouse_service_server::LighthouseService, LighthouseHeartbeatRequest,
+        LighthouseHeartbeatResponse, LighthouseQuorumRequest, LighthouseQuorumResponse,
+    },
+};
+
+/// Metadata header for both client and router
+pub const ROOM_ID_HEADER: &str = "room-id";
+
+/// Top-level service registered with tonic’s `Server::builder()`
+#[derive(Clone)]
+pub struct Router {
+    rooms: Arc<DashMap<String, Arc<Lighthouse>>>,
+    tmpl_opt: LighthouseOpt, // (cloned for each new room)
+}
+
+/// Designates a single tonic gRPC server into many logical “rooms.”
+/// Inspects the `room-id` metadata header on each request, then
+/// lazily creates or reuses an Arc<Lighthouse> for that namespace
+impl Router {
+    /// Create a new router given the CLI/config options that are
+    /// normally passed straight to `Lighthouse::new`.
+    pub fn new(tmpl_opt: LighthouseOpt) -> Self {
+        Self {
+            rooms: Arc::new(DashMap::new()),
+            tmpl_opt,
+        }
+    }
+
+    /// Room lookup: creation if it doesn't exist, access if it does
+    async fn room(&self, id: &str) -> Arc<Lighthouse> {
+        // 1. Quick optimistic read (no locking contention).
+        if let Some(handle) = self.rooms.get(id) {
+            return handle.clone();
+        }
+
+        // 2. Build the Lighthouse instance *off the map* so
+        //    we don't hold any guard across `.await`.
+        let new_room = Lighthouse::new(self.tmpl_opt.clone())
+            .await
+            .expect("failed to create Lighthouse");
+
+        // 3. Second pass: insert if still vacant, otherwise reuse
+        //    whatever another task inserted first.
+        match self.rooms.entry(id.to_owned()) {
+            Entry::Occupied(entry) => entry.get().clone(),
+            Entry::Vacant(entry) => {
+                entry.insert(new_room.clone());
+                new_room
+            }
+        }
+    }
+
+    /// Extracts `"room-id"` from metadata, defaulting to `"default"`.
+    fn extract_room_id(meta: &tonic::metadata::MetadataMap) -> &str {
+        meta.get(ROOM_ID_HEADER)
+            .and_then(|v| v.to_str().ok())
+            .unwrap_or("default")
+    }
+}
+
+#[tonic::async_trait]
+impl LighthouseService for Router {
+    async fn quorum(
+        &self,
+        req: Request<LighthouseQuorumRequest>,
+    ) -> Result<Response<LighthouseQuorumResponse>, Status> {
+        let id = Self::extract_room_id(req.metadata()).to_owned();
+        let room = self.room(&id).await;
+        <Arc<Lighthouse> as LighthouseService>::quorum(&room, req).await
+    }
+
+    async fn heartbeat(
+        &self,
+        req: Request<LighthouseHeartbeatRequest>,
+    ) -> Result<Response<LighthouseHeartbeatResponse>, Status> {
+        let id = Self::extract_room_id(req.metadata()).to_owned();
+        let room = self.room(&id).await;
+        <Arc<Lighthouse> as LighthouseService>::heartbeat(&room, req).await
+    }
+}
diff --git a/torchft/multi_quorum_test.py b/torchft/multi_quorum_test.py
new file mode 100644
index 00000000..e2ef132f
--- /dev/null
+++ b/torchft/multi_quorum_test.py
@@ -0,0 +1,46 @@
+"""
+Validate that one Lighthouse server can host isolated quorums
+for multiple logical rooms (job IDs) via `room-id` metadata header.
+"""
+
+from __future__ import annotations
+
+import datetime as _dt
+
+import pytest
+
+import torchft._torchft as ext
+
+_TIMEOUT = _dt.timedelta(seconds=3)  # connect + RPC timeout
+
+
+def _client(addr: str, room: str) -> ext.LighthouseClient:
+    """Utility: create a client with a logical room-id."""
+    return ext.LighthouseClient(addr, _TIMEOUT, room)
+
+
+@pytest.mark.asyncio
+async def test_multi_room_quorums() -> None:
+    # 1) one server, any free port
+    server = ext.LighthouseServer("[::]:0", 1)
+    addr = server.address()
+
+    # 2) two clients in two separate rooms
+    a = _client(addr, "jobA")
+    b = _client(addr, "jobB")
+
+    # 3) explicit heartbeats (exercises RPC path)
+    a.heartbeat("a0")
+    b.heartbeat("b0")
+
+    # 4) ask for a quorum from each room
+    qa = a.quorum("a0", _TIMEOUT)
+    qb = b.quorum("b0", _TIMEOUT)
+
+    # 5) verify the rooms are independent
+    assert qa.quorum_id == qb.quorum_id == 1
+    assert len(qa.participants) == 1 and qa.participants[0].replica_id == "a0"
+    assert len(qb.participants) == 1 and qb.participants[0].replica_id == "b0"
+
+    # 6) shutdown
+    server.shutdown()
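Note: the two-pass lookup in `Router::room` exists because a DashMap shard guard must not be held across an `.await`, and two tasks can race to create the same room. A minimal standalone sketch of that pattern, assuming dashmap 6.x and tokio; `Room`, `make_room`, and `get_or_create` are placeholder names standing in for `Arc<Lighthouse>`, `Lighthouse::new`, and `Router::room`:

    use std::sync::Arc;

    use dashmap::{mapref::entry::Entry, DashMap};

    type Room = Arc<String>;

    async fn make_room(id: &str) -> Room {
        Arc::new(id.to_owned()) // stands in for async construction
    }

    async fn get_or_create(rooms: &DashMap<String, Room>, id: &str) -> Room {
        // Fast path: the guard is dropped before any await point.
        if let Some(room) = rooms.get(id) {
            return room.clone();
        }
        // Build off the map so no shard lock is held across `.await`.
        let fresh = make_room(id).await;
        // Second pass: another task may have inserted meanwhile; reuse its value.
        match rooms.entry(id.to_owned()) {
            Entry::Occupied(e) => e.get().clone(),
            Entry::Vacant(v) => {
                v.insert(fresh.clone());
                fresh
            }
        }
    }

    #[tokio::main(flavor = "current_thread")]
    async fn main() {
        let rooms = DashMap::new();
        let a = get_or_create(&rooms, "jobA").await;
        let b = get_or_create(&rooms, "jobA").await;
        assert!(Arc::ptr_eq(&a, &b)); // second lookup reuses the first room
    }

The losing task's freshly built room is simply dropped, so the value type needs to be cheap to discard.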
From 5ab4c0cf500d5d0ecd78a7dbe061e669642883a9 Mon Sep 17 00:00:00 2001
From: Matt Kotzbauer
Date: Wed, 7 May 2025 15:48:03 -0400
Subject: [PATCH 2/5] Interceptor attached via LighthouseClient constructor
 rather than using add_room_header for each RPC call

---
 src/bin/lighthouse.rs |  4 +++-
 src/interceptor.rs    | 23 ++++++++++++++++++
 src/interceptor.rs~   | 12 ++++++++++
 src/lib.rs            | 55 ++++++++++++++++++-------------------------
 4 files changed, 64 insertions(+), 30 deletions(-)
 create mode 100644 src/interceptor.rs
 create mode 100644 src/interceptor.rs~

diff --git a/src/bin/lighthouse.rs b/src/bin/lighthouse.rs
index e9f944f5..92fc293f 100644
--- a/src/bin/lighthouse.rs
+++ b/src/bin/lighthouse.rs
@@ -6,9 +6,11 @@
 
 use std::net::SocketAddr;
 use structopt::StructOpt;
+use tonic::transport::Server;
 use torchft::lighthouse::LighthouseOpt;
+use torchft::torchftpb::lighthouse_service_server::LighthouseServiceServer;
 use torchft::router::Router;
-use torchftpb::lighthouse_service_server::LighthouseServiceServer;
+
 
 #[tokio::main(flavor = "multi_thread", worker_threads = 4)]
 async fn main() {
diff --git a/src/interceptor.rs b/src/interceptor.rs
new file mode 100644
index 00000000..595408b9
--- /dev/null
+++ b/src/interceptor.rs
@@ -0,0 +1,23 @@
+use tonic::{service::Interceptor, metadata::MetadataValue, Request, Status};
+
+/// Attaches user-assigned room-id header to every outbound RPC
+#[derive(Clone)]
+pub struct RoomIdInterceptor {
+    room: String,
+}
+
+impl RoomIdInterceptor {
+    pub fn new(room: String) -> Self {
+        Self { room }
+    }
+}
+
+impl Interceptor for RoomIdInterceptor {
+    fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, Status> {
+        req.metadata_mut().insert(
+            crate::router::ROOM_ID_HEADER,
+            MetadataValue::try_from(self.room.as_str()).expect("ascii header"),
+        );
+        Ok(req)
+    }
+}
diff --git a/src/interceptor.rs~ b/src/interceptor.rs~
new file mode 100644
index 00000000..2b3e8c3f
--- /dev/null
+++ b/src/interceptor.rs~
@@ -0,0 +1,12 @@
+use tonic::{Request, Status, service::Interceptor};
+use tonic::metadata::MetadataValue;
+
+pub fn room_id_interceptor(room: String) -> impl Interceptor {
+    move |mut req: Request<()>| {
+        req.metadata_mut().insert(
+            crate::router::ROOM_ID_HEADER,
+            MetadataValue::try_from(room.as_str()).expect("ascii header"),
+        );
+        Ok(req) // returning Err(Status) would cancel the call
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index f9ccf780..087a03ca 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,10 +6,11 @@
 
 pub mod lighthouse;
 pub mod manager;
+pub mod router;
 mod net;
 mod retry;
-mod router;
 mod timeout;
+mod interceptor;
 
 pub use crate::router::Router;
 
@@ -25,8 +26,9 @@
 use structopt::StructOpt;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
 use tokio_stream::wrappers::TcpListenerStream;
-use tonic::transport::Channel;
+use tonic::transport::{Channel, Endpoint};
 use tonic::Status;
+use tonic::service::interceptor::InterceptedService;
 
 use chrono::Local;
 use fern::colors::{Color, ColoredLevelConfig};
@@ -36,6 +38,7 @@ pub mod torchftpb {
     tonic::include_proto!("torchft");
 }
 
+use crate::interceptor::RoomIdInterceptor;
 use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
 use crate::torchftpb::lighthouse_service_server::LighthouseServiceServer;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
@@ -486,9 +489,10 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
 /// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
 #[pyclass]
 struct LighthouseClient {
-    client: LighthouseServiceClient<Channel>,
+    client: LighthouseServiceClient<
+        InterceptedService<Channel, RoomIdInterceptor>
+    >,
     runtime: Runtime,
-    room_id: Option<String>,
 }
 
 #[pymethods]
 impl LighthouseClient {
@@ -507,14 +511,25 @@ impl LighthouseClient {
             .thread_name("torchft-lhclnt")
             .enable_all()
             .build()?;
-            let client = runtime
-                .block_on(manager::lighthouse_client_new(addr, connect_timeout))
+
+            let endpoint = Endpoint::from_shared(addr.clone())
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
-            Ok(Self {
-                client: client,
-                runtime: runtime,
-                room_id: room_id,
-            })
+            let channel = runtime
+                .block_on(
+                    endpoint
+                        .connect_timeout(connect_timeout)
+                        .connect(),
+                )
+                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+
+            let interceptor =
+                RoomIdInterceptor::new(room_id.unwrap_or_else(|| "default".to_owned()));
+
+            let client =
+                LighthouseServiceClient::with_interceptor(channel, interceptor);
+
+            Ok(Self { client, runtime })
         })
     }
 
@@ -569,8 +584,6 @@ impl LighthouseClient {
             }),
         });
 
-        let mut request = self.add_room_header(request);
-
         // This timeout is processed on the server side so we also enable
         // keep alives to detect server health.
         request.set_timeout(timeout);
@@ -599,7 +612,6 @@ impl LighthouseClient {
     ) -> Result<(), StatusError> {
         py.allow_threads(move || {
             let mut req = tonic::Request::new(LighthouseHeartbeatRequest { replica_id });
-            let mut req = self.add_room_header(req);
             req.set_timeout(timeout);
             self.runtime.block_on(self.client.clone().heartbeat(req))?;
             Ok(())
@@ -607,21 +619,6 @@ impl LighthouseClient {
     }
 }
 
-impl LighthouseClient {
-    /// Attach `"room-id"` header if `self.room_id` is Some(_)
-    fn add_room_header<T>(&self, mut req: tonic::Request<T>) -> tonic::Request<T> {
-        if let Some(ref id) = self.room_id {
-            use tonic::metadata::MetadataValue;
-            req.metadata_mut().insert(
-                crate::router::ROOM_ID_HEADER,
-                MetadataValue::try_from(id.as_str()).expect("room-id ascii"),
-            );
-        }
-        req
-    }
-}
-
 /// LighthouseServer is a GRPC server for the lighthouse service.
 ///
 /// It is used to coordinate the ManagerServer for each replica group.
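Note: with this patch the header is attached once, at client construction, instead of at every call site. The same wiring works for any tonic 0.12 client; a hedged standalone sketch (`RoomId`, `connect`, and the parameter values are illustrative, while `InterceptedService::new` and `Channel::from_shared` are real tonic API; generated clients wrap the same combinator via `with_interceptor`):

    use tonic::{
        metadata::MetadataValue,
        service::{interceptor::InterceptedService, Interceptor},
        transport::Channel,
        Request, Status,
    };

    #[derive(Clone)]
    struct RoomId(String);

    impl Interceptor for RoomId {
        fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, Status> {
            // Panics for non-ASCII ids, mirroring the patch's expect().
            req.metadata_mut()
                .insert("room-id", MetadataValue::try_from(self.0.as_str()).unwrap());
            Ok(req)
        }
    }

    async fn connect(addr: String, room: String) -> Result<(), Box<dyn std::error::Error>> {
        let channel = Channel::from_shared(addr)?.connect().await?;
        // Every request sent through `_svc` now carries the room-id header.
        let _svc = InterceptedService::new(channel, RoomId(room));
        Ok(())
    }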
From 0a9ce34098c7fbcca543ca2e0789a5c686d16ca2 Mon Sep 17 00:00:00 2001
From: Matt Kotzbauer
Date: Fri, 16 May 2025 17:44:44 -0400
Subject: [PATCH 3/5] Tonic-level routing changed to tower-level in
 src/router.rs - Server::builder calls (in src/bin/lighthouse.rs, src/lib.rs)
 and torchft/multi_quorum_test.py modified to reflect change.

---
 Cargo.toml                   |   4 ++
 src/bin/lighthouse.rs        |   8 +--
 src/lib.rs                   |  33 ++++-----
 src/router.rs                | 128 ++++++++++++++++++++---------------
 torchft/multi_quorum_test.py |  48 +++++++------
 5 files changed, 119 insertions(+), 102 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 6ff04507..9803a336 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,7 +11,10 @@ axum = "0.7.7"
 chrono = "0.4.40"
 dashmap = "6.1"
 fern = {version = "0.7.1", features = ["colored"]}
+futures = "0.3"
 gethostname = "0.5.0"
+hyper = "0.14"
+http = "0.2"
 log = "0.4.22"
 prost = "0.13.3"
 prost-types = "0.13.3"
@@ -24,6 +27,7 @@ structopt = "0.3.26"
 tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] }
 tokio-stream = "0.1"
 tonic = "0.12.2"
+tower = "0.4"
 
 [build-dependencies]
 tonic-build = "0.12.2"
diff --git a/src/bin/lighthouse.rs b/src/bin/lighthouse.rs
index 92fc293f..70c13f7c 100644
--- a/src/bin/lighthouse.rs
+++ b/src/bin/lighthouse.rs
@@ -8,10 +8,8 @@
 use std::net::SocketAddr;
 use structopt::StructOpt;
 use tonic::transport::Server;
 use torchft::lighthouse::LighthouseOpt;
-use torchft::torchftpb::lighthouse_service_server::LighthouseServiceServer;
 use torchft::router::Router;
-
 
 #[tokio::main(flavor = "multi_thread", worker_threads = 4)]
 async fn main() {
     stderrlog::new()
@@ -23,9 +21,11 @@ async fn main() {
 
     let opt = LighthouseOpt::from_args();
     let router = Router::new(opt.clone());
+    let addr: SocketAddr = opt.bind.parse().expect("invalid --bind address");
 
     Server::builder()
-        .add_service(LighthouseServiceServer::new(router))
-        .serve(opt.bind.parse::<SocketAddr>().unwrap())
+        .add_service(router)
+        .serve(addr)
         .await
         .unwrap();
 }
diff --git a/src/lib.rs b/src/lib.rs
index 087a03ca..97865bab 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -4,13 +4,13 @@
 // This source code is licensed under the BSD-style license found in the
 // LICENSE file in the root directory of this source tree.
 
+mod interceptor;
 pub mod lighthouse;
 pub mod manager;
-pub mod router;
 mod net;
 mod retry;
+pub mod router;
 mod timeout;
-mod interceptor;
 
 pub use crate::router::Router;
 
@@ -20,15 +20,16 @@
 use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
 use std::cmp;
 use std::env;
+use std::net::SocketAddr;
 use std::sync::Arc;
 use std::thread::available_parallelism;
 use structopt::StructOpt;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
 use tokio_stream::wrappers::TcpListenerStream;
+use tonic::service::interceptor::InterceptedService;
 use tonic::transport::{Channel, Endpoint};
 use tonic::Status;
-use tonic::service::interceptor::InterceptedService;
 
 use chrono::Local;
 use fern::colors::{Color, ColoredLevelConfig};
@@ -40,9 +41,7 @@ pub mod torchftpb {
 
 use crate::interceptor::RoomIdInterceptor;
 use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
-use crate::torchftpb::lighthouse_service_server::LighthouseServiceServer;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
-use crate::torchftpb::LighthouseHeartbeatRequest;
 use crate::torchftpb::{
     CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest,
     ManagerQuorumRequest, ShouldCommitRequest,
@@ -349,10 +348,11 @@
 
 async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
     let router = Router::new(opt.clone());
+    let addr: SocketAddr = opt.bind.parse()?;
 
     tonic::transport::Server::builder()
-        .add_service(LighthouseServiceServer::new(router))
-        .serve(opt.bind.parse::<SocketAddr>()?)
+        .add_service(router)
+        .serve(addr)
         .await?;
 
     Ok(())
@@ -489,9 +489,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
 /// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
 #[pyclass]
 struct LighthouseClient {
-    client: LighthouseServiceClient<
-        InterceptedService<Channel, RoomIdInterceptor>
-    >,
+    client: LighthouseServiceClient<InterceptedService<Channel, RoomIdInterceptor>>,
     runtime: Runtime,
 }
 
@@ -515,21 +513,15 @@ impl LighthouseClient {
             let endpoint = Endpoint::from_shared(addr.clone())
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
             let channel = runtime
-                .block_on(
-                    endpoint
-                        .connect_timeout(connect_timeout)
-                        .connect(),
-                )
+                .block_on(endpoint.connect_timeout(connect_timeout).connect())
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
             let interceptor =
                 RoomIdInterceptor::new(room_id.unwrap_or_else(|| "default".to_owned()));
 
-            let client =
-                LighthouseServiceClient::with_interceptor(channel, interceptor);
+            let client = LighthouseServiceClient::with_interceptor(channel, interceptor);
 
-            Ok(Self { client, runtime })
-
+            Ok(Self { client, runtime })
         })
     }
 
@@ -674,10 +666,11 @@
         let bound_sock = listener.local_addr()?;
         let bound = format!("http://{}", bound_sock);
         let incoming = TcpListenerStream::new(listener);
+        let router = Router::new(opt.clone());
 
         let handle = rt.spawn(async move {
             tonic::transport::Server::builder()
-                .add_service(LighthouseServiceServer::new(Router::new(opt.clone())))
+                .add_service(router)
                 .serve_with_incoming(incoming)
                 .await
                 .map_err(|e: tonic::transport::Error| anyhow::anyhow!(e))
diff --git a/src/router.rs b/src/router.rs
index fd521577..a9800192 100644
--- a/src/router.rs
+++ b/src/router.rs
@@ -1,32 +1,38 @@
-use std::sync::Arc;
+use std::{
+    convert::Infallible,
+    future::Future,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
 
 use dashmap::{mapref::entry::Entry, DashMap};
-use tonic::{Request, Response, Status};
+use futures::FutureExt;
+use tonic::{
+    body::BoxBody,
+    codegen::http::{HeaderMap, Request, Response}, // http-0.2 types
+    server::NamedService,
+};
+use tower::Service;
 
 use crate::{
     lighthouse::{Lighthouse, LighthouseOpt},
-    torchftpb::{
-        lighthouse_service_server::LighthouseService, LighthouseHeartbeatRequest,
-        LighthouseHeartbeatResponse, LighthouseQuorumRequest, LighthouseQuorumResponse,
-    },
+    torchftpb::lighthouse_service_server::LighthouseServiceServer,
 };
 
-/// Metadata header for both client and router
+/// Metadata header recognised by both client interceptor and this router.
 pub const ROOM_ID_HEADER: &str = "room-id";
 
-/// Top-level service registered with tonic’s `Server::builder()`
+/// gRPC server for a single room (inner state = `Arc<Lighthouse>`).
+type GrpcSvc = LighthouseServiceServer<Arc<Lighthouse>>;
+
 #[derive(Clone)]
 pub struct Router {
-    rooms: Arc<DashMap<String, Arc<Lighthouse>>>,
-    tmpl_opt: LighthouseOpt, // (cloned for each new room)
+    rooms: Arc<DashMap<String, Arc<GrpcSvc>>>,
+    tmpl_opt: LighthouseOpt,
 }
 
-/// Designates a single tonic gRPC server into many logical “rooms.”
-/// Inspects the `room-id` metadata header on each request, then
-/// lazily creates or reuses an Arc<Lighthouse> for that namespace
 impl Router {
-    /// Create a new router given the CLI/config options that are
-    /// normally passed straight to `Lighthouse::new`.
     pub fn new(tmpl_opt: LighthouseOpt) -> Self {
         Self {
             rooms: Arc::new(DashMap::new()),
@@ -34,55 +40,71 @@ impl Router {
         }
     }
 
-    /// Room lookup: creation if it doesn't exist, access if it does
-    async fn room(&self, id: &str) -> Arc<Lighthouse> {
-        // 1. Quick optimistic read (no locking contention).
-        if let Some(handle) = self.rooms.get(id) {
-            return handle.clone();
+    fn room_id(hdrs: &HeaderMap) -> &str {
+        hdrs.get(ROOM_ID_HEADER)
+            .and_then(|v| v.to_str().ok())
+            .unwrap_or("default")
+    }
+
+    async fn room_service(
+        rooms: Arc<DashMap<String, Arc<GrpcSvc>>>,
+        tmpl: LighthouseOpt,
+        id: &str,
+    ) -> Arc<GrpcSvc> {
+        if let Some(svc) = rooms.get(id) {
+            return svc.clone();
         }
 
-        // 2. Build the Lighthouse instance *off the map* so
-        //    we don't hold any guard across `.await`.
-        let new_room = Lighthouse::new(self.tmpl_opt.clone())
+        // Build room state once.
+        let lh = Lighthouse::new(tmpl.clone())
             .await
             .expect("failed to create Lighthouse");
 
-        // 3. Second pass: insert if still vacant, otherwise reuse
-        //    whatever another task inserted first.
-        match self.rooms.entry(id.to_owned()) {
-            Entry::Occupied(entry) => entry.get().clone(),
-            Entry::Vacant(entry) => {
-                entry.insert(new_room.clone());
-                new_room
+        let svc_new = Arc::new(LighthouseServiceServer::new(lh));
+
+        match rooms.entry(id.to_owned()) {
+            Entry::Occupied(e) => e.get().clone(),
+            Entry::Vacant(v) => {
+                v.insert(svc_new.clone());
+                svc_new
             }
         }
     }
-
-    /// Extracts `"room-id"` from metadata, defaulting to `"default"`.
-    fn extract_room_id(meta: &tonic::metadata::MetadataMap) -> &str {
-        meta.get(ROOM_ID_HEADER)
-            .and_then(|v| v.to_str().ok())
-            .unwrap_or("default")
-    }
 }
 
-#[tonic::async_trait]
-impl LighthouseService for Router {
-    async fn quorum(
-        &self,
-        req: Request<LighthouseQuorumRequest>,
-    ) -> Result<Response<LighthouseQuorumResponse>, Status> {
-        let id = Self::extract_room_id(req.metadata()).to_owned();
-        let room = self.room(&id).await;
-        <Arc<Lighthouse> as LighthouseService>::quorum(&room, req).await
+// Tower::Service implementation
+impl Service<Request<BoxBody>> for Router {
+    type Response = Response<BoxBody>;
+    type Error = Infallible;
+    type Future =
+        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
+
+    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Poll::Ready(Ok(()))
     }
 
-    async fn heartbeat(
-        &self,
-        req: Request<LighthouseHeartbeatRequest>,
-    ) -> Result<Response<LighthouseHeartbeatResponse>, Status> {
-        let id = Self::extract_room_id(req.metadata()).to_owned();
-        let room = self.room(&id).await;
-        <Arc<Lighthouse> as LighthouseService>::heartbeat(&room, req).await
+    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
+        let rooms = self.rooms.clone();
+        let tmpl = self.tmpl_opt.clone();
+        let room = Self::room_id(req.headers()).to_owned();
+
+        async move {
+            let svc_arc = Self::room_service(rooms, tmpl, &room).await;
+
+            // `Arc<GrpcSvc>` itself isn’t a Service; clone the inner value.
+            let mut svc = (*svc_arc).clone();
+            let resp = svc
+                .call(req)
+                .await
+                .map_err(|_e| -> Infallible { unreachable!() })?;
+
+            Ok(resp)
+        }
+        .boxed()
     }
 }
+
+// Forward tonic’s NamedService marker
+impl NamedService for Router {
+    const NAME: &'static str = <GrpcSvc as NamedService>::NAME;
+}
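Note: the `NamedService` forwarding at the end matters because tonic mounts each service at the HTTP/2 path prefix `/<NAME>/<method>`; by reusing the generated server's NAME, the router keeps answering on the exact paths the generated clients dial. Given `tonic::include_proto!("torchft")` earlier in the series, that constant should be the package-qualified service name, presumably "torchft.LighthouseService"; the stub below is illustrative, not the generated code:

    use tonic::server::NamedService;

    struct Stub;

    impl NamedService for Stub {
        // Hypothetical stand-in: proto package "torchft", service "LighthouseService".
        const NAME: &'static str = "torchft.LighthouseService";
    }

    fn main() {
        // Requests then arrive as POST /torchft.LighthouseService/Quorum, etc.
        assert_eq!(<Stub as NamedService>::NAME, "torchft.LighthouseService");
    }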
-""" - from __future__ import annotations import datetime as _dt +import time import pytest import torchft._torchft as ext -_TIMEOUT = _dt.timedelta(seconds=3) # connect + RPC timeout - +_TIMEOUT = _dt.timedelta(seconds=3) def _client(addr: str, room: str) -> ext.LighthouseClient: - """Utility: create a client with a logical room-id.""" + """Helper: create a LighthouseClient bound to a logical room.""" return ext.LighthouseClient(addr, _TIMEOUT, room) @pytest.mark.asyncio async def test_multi_room_quorums() -> None: - # 1) one server, any free port - server = ext.LighthouseServer("[::]:0", 1) - addr = server.address() + # 1) Launch one Lighthouse server on any free port + server = ext.LighthouseServer("[::]:0", min_replicas=1) + addr: str = server.address() + + # (give the Tokio runtime a tick to bind the listener) + time.sleep(0.1) - # 2) two clients in two separate rooms - a = _client(addr, "jobA") - b = _client(addr, "jobB") + # 2) Two clients, each in its own room + cli_a = _client(addr, "jobA") + cli_b = _client(addr, "jobB") - # 3) explicit heartbeats (exercises RPC path) - a.heartbeat("a0") - b.heartbeat("b0") + # 3) Explicit heart-beats (exercise the RPC path) + cli_a.heartbeat("a0") + cli_b.heartbeat("b0") - # 4) ask for a quorum from each room - qa = a.quorum("a0", _TIMEOUT) - qb = b.quorum("b0", _TIMEOUT) + # 4) Ask each room for a quorum + q_a = cli_a.quorum("a0", _TIMEOUT) + q_b = cli_b.quorum("b0", _TIMEOUT) - # 5) verify the rooms are independent - assert qa.quorum_id == qb.quorum_id == 1 - assert len(qa.participants) == 1 and qa.participants[0].replica_id == "a0" - assert len(qb.participants) == 1 and qb.participants[0].replica_id == "b0" + # 5) Assert the rooms are isolated + assert q_a.quorum_id == q_b.quorum_id == 1 + assert len(q_a.participants) == 1 and q_a.participants[0].replica_id == "a0" + assert len(q_b.participants) == 1 and q_b.participants[0].replica_id == "b0" - # 6) shutdown + # 6) Clean shutdown server.shutdown() From 273d3ee238c83f5926f905b8a4e9166f1d2296b7 Mon Sep 17 00:00:00 2001 From: Matt Kotzbauer Date: Wed, 21 May 2025 16:36:27 -0400 Subject: [PATCH 4/5] Edits to tower-based routing: src/router.rs room return type changed to Arc, Lighthouse::new now takes id prefix, test relocated to lighthouse_test.py and now uses coordination API, LighthouseServer now resolves host/port from the bound socket to give a routable http://host:port address --- src/interceptor.rs~ | 12 ---------- src/lib.rs | 11 +++++++-- src/lighthouse.rs | 4 +++- src/router.rs | 26 +++++++++------------ torchft/lighthouse_test.py | 30 ++++++++++++++++++++++++ torchft/multi_quorum_test.py | 44 ------------------------------------ 6 files changed, 53 insertions(+), 74 deletions(-) delete mode 100644 src/interceptor.rs~ delete mode 100644 torchft/multi_quorum_test.py diff --git a/src/interceptor.rs~ b/src/interceptor.rs~ deleted file mode 100644 index 2b3e8c3f..00000000 --- a/src/interceptor.rs~ +++ /dev/null @@ -1,12 +0,0 @@ -use tonic::{Request, Status, service::Interceptor}; -use tonic::metadata::MetadataValue; - -pub fn room_id_interceptor(room: String) -> impl Interceptor { - move |mut req: Request<()>| { - req.metadata_mut().insert( - crate::router::ROOM_ID_HEADER, - MetadataValue::try_from(room.as_str()).expect("ascii header"), - ); - Ok(req) // returning Err(Status) would cancel the call :contentReference[oaicite:0]{index=0} - } -} diff --git a/src/lib.rs b/src/lib.rs index 97865bab..7f0815fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ pub use crate::router::Router; 
From 273d3ee238c83f5926f905b8a4e9166f1d2296b7 Mon Sep 17 00:00:00 2001
From: Matt Kotzbauer
Date: Wed, 21 May 2025 16:36:27 -0400
Subject: [PATCH 4/5] Edits to tower-based routing: src/router.rs room return
 type changed to Arc<Lighthouse>, Lighthouse::new now takes id prefix, test
 relocated to lighthouse_test.py and now uses coordination API,
 LighthouseServer now resolves host/port from the bound socket to give a
 routable http://host:port address

---
 src/interceptor.rs~          | 12 ----------
 src/lib.rs                   | 11 +++++++--
 src/lighthouse.rs            |  4 +++-
 src/router.rs                | 26 +++++++++------------
 torchft/lighthouse_test.py   | 30 ++++++++++++++++++++++++
 torchft/multi_quorum_test.py | 44 ------------------------------------
 6 files changed, 53 insertions(+), 74 deletions(-)
 delete mode 100644 src/interceptor.rs~
 delete mode 100644 torchft/multi_quorum_test.py

diff --git a/src/interceptor.rs~ b/src/interceptor.rs~
deleted file mode 100644
index 2b3e8c3f..00000000
--- a/src/interceptor.rs~
+++ /dev/null
@@ -1,12 +0,0 @@
-use tonic::{Request, Status, service::Interceptor};
-use tonic::metadata::MetadataValue;
-
-pub fn room_id_interceptor(room: String) -> impl Interceptor {
-    move |mut req: Request<()>| {
-        req.metadata_mut().insert(
-            crate::router::ROOM_ID_HEADER,
-            MetadataValue::try_from(room.as_str()).expect("ascii header"),
-        );
-        Ok(req) // returning Err(Status) would cancel the call
-    }
-}
diff --git a/src/lib.rs b/src/lib.rs
index 97865bab..7f0815fa 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -17,6 +17,7 @@ pub use crate::router::Router;
 use anyhow::Result;
 use atty::Stream;
 use core::time::Duration;
+use gethostname::gethostname;
 use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
 use std::cmp;
 use std::env;
@@ -664,7 +665,6 @@ impl LighthouseServer {
 
         let listener = rt.block_on(tokio::net::TcpListener::bind(&bind))?;
         let bound_sock = listener.local_addr()?;
-        let bound = format!("http://{}", bound_sock);
         let incoming = TcpListenerStream::new(listener);
         let router = Router::new(opt.clone());
 
@@ -676,8 +676,15 @@ impl LighthouseServer {
                 .map_err(|e: tonic::transport::Error| anyhow::anyhow!(e))
         });
 
+        let host = if bind.starts_with("0.0.0.0") || bind.starts_with("[::]") {
+            gethostname().to_string_lossy().into_owned()
+        } else {
+            bind.rsplit_once(':').map(|(h, _)| h.to_string()).unwrap()
+        };
+        let public_addr = format!("http://{}:{}", host, bound_sock.port());
+
         Ok(Self {
-            bind: bound,
+            bind: public_addr,
             handle,
             _runtime: rt,
         })
diff --git a/src/lighthouse.rs b/src/lighthouse.rs
index 2cdfda50..a8b082e9 100644
--- a/src/lighthouse.rs
+++ b/src/lighthouse.rs
@@ -58,6 +58,7 @@ struct State {
 }
 
 pub struct Lighthouse {
+    id: String,
     state: Mutex<State>,
     opt: LighthouseOpt,
     listener: Mutex<Option<tokio::net::TcpListener>>,
@@ -261,12 +262,13 @@ fn quorum_compute(
 }
 
 impl Lighthouse {
-    pub async fn new(opt: LighthouseOpt) -> Result<Arc<Self>> {
+    pub async fn new(id: String, opt: LighthouseOpt) -> Result<Arc<Self>> {
         let listener = tokio::net::TcpListener::bind(&opt.bind).await?;
 
         let (tx, _) = broadcast::channel(16);
 
         Ok(Arc::new(Self {
+            id: id,
             state: Mutex::new(State {
                 participants: HashMap::new(),
                 channel: tx,
diff --git a/src/router.rs b/src/router.rs
index a9800192..22d13546 100644
--- a/src/router.rs
+++ b/src/router.rs
@@ -10,7 +10,7 @@ use dashmap::{mapref::entry::Entry, DashMap};
 use futures::FutureExt;
 use tonic::{
     body::BoxBody,
-    codegen::http::{HeaderMap, Request, Response}, // http-0.2 types
+    codegen::http::{HeaderMap, Request, Response},
     server::NamedService,
 };
 use tower::Service;
@@ -28,7 +28,7 @@ pub const ROOM_ID_HEADER: &str = "room-id";
 
 #[derive(Clone)]
 pub struct Router {
-    rooms: Arc<DashMap<String, Arc<GrpcSvc>>>,
+    rooms: Arc<DashMap<String, Arc<Lighthouse>>>,
     tmpl_opt: LighthouseOpt,
 }
 
@@ -47,26 +47,23 @@ impl Router {
     }
 
     async fn room_service(
-        rooms: Arc<DashMap<String, Arc<GrpcSvc>>>,
+        rooms: Arc<DashMap<String, Arc<Lighthouse>>>,
         tmpl: LighthouseOpt,
         id: &str,
-    ) -> Arc<GrpcSvc> {
-        if let Some(svc) = rooms.get(id) {
-            return svc.clone();
+    ) -> Arc<Lighthouse> {
+        if let Some(lh) = rooms.get(id) {
+            return lh.clone();
         }
 
-        // Build room state once.
-        let lh = Lighthouse::new(tmpl.clone())
+        let lh = Lighthouse::new(id.to_owned(), tmpl.clone())
             .await
             .expect("failed to create Lighthouse");
 
-        let svc_new = Arc::new(LighthouseServiceServer::new(lh));
-
         match rooms.entry(id.to_owned()) {
             Entry::Occupied(e) => e.get().clone(),
             Entry::Vacant(v) => {
-                v.insert(svc_new.clone());
-                svc_new
+                v.insert(lh.clone());
+                lh
             }
         }
     }
@@ -89,10 +86,9 @@ impl Service<Request<BoxBody>> for Router {
         let room = Self::room_id(req.headers()).to_owned();
 
         async move {
-            let svc_arc = Self::room_service(rooms, tmpl, &room).await;
+            let lh = Self::room_service(rooms, tmpl, &room).await;
 
-            // `Arc<GrpcSvc>` itself isn’t a Service; clone the inner value.
-            let mut svc = (*svc_arc).clone();
+            let mut svc = LighthouseServiceServer::new(lh);
             let resp = svc
                 .call(req)
                 .await
diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py
index 067a6222..c26cacd4 100644
--- a/torchft/lighthouse_test.py
+++ b/torchft/lighthouse_test.py
@@ -4,6 +4,7 @@
 
 import torch.distributed as dist
 
+import torchft.coordination as cd
 from torchft import Manager, ProcessGroupGloo
 from torchft._torchft import LighthouseClient, LighthouseServer, Quorum, QuorumMember
 
@@ -155,3 +156,32 @@
 
         finally:
             lighthouse.shutdown()
+
+    def test_multi_room_quorums(self) -> None:
+        """One server, two logical rooms should yield two isolated quorums."""
+        server = cd.LighthouseServer(bind="[::]:0", min_replicas=1)
+        addr = server.address()
+
+        try:
+            # Two clients in two independent rooms
+            cli_a = cd.LighthouseClient(addr, timedelta(seconds=1), room_id="jobA")
+            cli_b = cd.LighthouseClient(addr, timedelta(seconds=1), room_id="jobB")
+
+            # Explicit heartbeat so each room has one participant
+            cli_a.heartbeat("a0")
+            cli_b.heartbeat("b0")
+
+            q_a = cli_a.quorum("a0", timedelta(seconds=1))
+            q_b = cli_b.quorum("b0", timedelta(seconds=1))
+
+            # Both rooms got a quorum-id of 1 but with disjoint members
+            self.assertEqual(q_a.quorum_id, 1)
+            self.assertEqual(q_b.quorum_id, 1)
+
+            self.assertEqual(len(q_a.participants), 1)
+            self.assertEqual(len(q_b.participants), 1)
+            self.assertEqual(q_a.participants[0].replica_id, "a0")
+            self.assertEqual(q_b.participants[0].replica_id, "b0")
+
+        finally:
+            server.shutdown()
diff --git a/torchft/multi_quorum_test.py b/torchft/multi_quorum_test.py
deleted file mode 100644
index 8bfe16ec..00000000
--- a/torchft/multi_quorum_test.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from __future__ import annotations
-
-import datetime as _dt
-import time
-
-import pytest
-
-import torchft._torchft as ext
-
-_TIMEOUT = _dt.timedelta(seconds=3)
-
-def _client(addr: str, room: str) -> ext.LighthouseClient:
-    """Helper: create a LighthouseClient bound to a logical room."""
-    return ext.LighthouseClient(addr, _TIMEOUT, room)
-
-
-@pytest.mark.asyncio
-async def test_multi_room_quorums() -> None:
-    # 1) Launch one Lighthouse server on any free port
-    server = ext.LighthouseServer("[::]:0", min_replicas=1)
-    addr: str = server.address()
-
-    # (give the Tokio runtime a tick to bind the listener)
-    time.sleep(0.1)
-
-    # 2) Two clients, each in its own room
-    cli_a = _client(addr, "jobA")
-    cli_b = _client(addr, "jobB")
-
-    # 3) Explicit heart-beats (exercise the RPC path)
-    cli_a.heartbeat("a0")
-    cli_b.heartbeat("b0")
-
-    # 4) Ask each room for a quorum
-    q_a = cli_a.quorum("a0", _TIMEOUT)
-    q_b = cli_b.quorum("b0", _TIMEOUT)
-
-    # 5) Assert the rooms are isolated
-    assert q_a.quorum_id == q_b.quorum_id == 1
-    assert len(q_a.participants) == 1 and q_a.participants[0].replica_id == "a0"
-    assert len(q_b.participants) == 1 and q_b.participants[0].replica_id == "b0"
-
-    # 6) Clean shutdown
-    server.shutdown()
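Note: the address logic added to `LighthouseServer` follows one rule: if the bind host is a wildcard (`0.0.0.0` or `[::]`), advertise the machine hostname; otherwise keep the host that was given, and always use the port actually bound (important when the caller asked for port 0). A standalone sketch of the same rule; `public_addr` is a hypothetical helper, and the `unwrap_or_else` fallback for a host-only bind string is an assumption this sketch adds (the patch itself unwraps):

    use gethostname::gethostname;

    fn public_addr(bind: &str, bound_port: u16) -> String {
        let host = if bind.starts_with("0.0.0.0") || bind.starts_with("[::]") {
            // Wildcard binds are not routable; advertise the hostname instead.
            gethostname().to_string_lossy().into_owned()
        } else {
            bind.rsplit_once(':')
                .map(|(h, _)| h.to_string())
                .unwrap_or_else(|| bind.to_string())
        };
        format!("http://{}:{}", host, bound_port)
    }

    fn main() {
        assert_eq!(public_addr("10.0.0.5:29510", 29510), "http://10.0.0.5:29510");
        println!("{}", public_addr("[::]:0", 41234)); // http://<hostname>:41234
    }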
From 53ec8bec04dc62604011ce9210eabb244ee58a4a Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Thu, 29 May 2025 15:45:51 -0700
Subject: [PATCH 5/5] lint

---
 src/interceptor.rs   |  2 +-
 src/lighthouse.rs    |  6 ++--
 src/manager.rs       | 68 ++++++++++++++++++++++++++------------------
 torchft/_torchft.pyi |  1 +
 4 files changed, 45 insertions(+), 32 deletions(-)

diff --git a/src/interceptor.rs b/src/interceptor.rs
index 595408b9..fa7b7c14 100644
--- a/src/interceptor.rs
+++ b/src/interceptor.rs
@@ -1,4 +1,4 @@
-use tonic::{service::Interceptor, metadata::MetadataValue, Request, Status};
+use tonic::{metadata::MetadataValue, service::Interceptor, Request, Status};
 
 /// Attaches user-assigned room-id header to every outbound RPC
 #[derive(Clone)]
diff --git a/src/lighthouse.rs b/src/lighthouse.rs
index a8b082e9..989cc13c 100644
--- a/src/lighthouse.rs
+++ b/src/lighthouse.rs
@@ -977,7 +977,7 @@ mod tests {
             quorum_tick_ms: 10,
             heartbeat_timeout_ms: 5000,
         };
-        let lighthouse = Lighthouse::new(opt).await?;
+        let lighthouse = Lighthouse::new("".to_string(), opt).await?;
 
         let lighthouse_task = tokio::spawn(lighthouse.clone().run());
 
@@ -1135,7 +1135,7 @@ mod tests {
         };
 
         // Start the lighthouse service
-        let lighthouse = Lighthouse::new(opt).await?;
+        let lighthouse = Lighthouse::new("".to_string(), opt).await?;
         let lighthouse_task = tokio::spawn(lighthouse.clone().run());
 
         // Create client to interact with lighthouse
@@ -1242,7 +1242,7 @@ mod tests {
         };
 
         // Start the lighthouse service
-        let lighthouse = Lighthouse::new(opt).await?;
+        let lighthouse = Lighthouse::new("".to_string(), opt).await?;
         let lighthouse_task = tokio::spawn(lighthouse.clone().run());
 
         // Create client to interact with lighthouse
diff --git a/src/manager.rs b/src/manager.rs
index e28cbeb5..79d02657 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -544,13 +544,16 @@ mod tests {
 
     #[tokio::test]
     async fn test_should_commit() -> Result<()> {
-        let lighthouse = Lighthouse::new(LighthouseOpt {
-            bind: "[::]:0".to_string(),
-            join_timeout_ms: 100,
-            min_replicas: 1,
-            quorum_tick_ms: 100,
-            heartbeat_timeout_ms: 5000,
-        })
+        let lighthouse = Lighthouse::new(
+            "".to_string(),
+            LighthouseOpt {
+                bind: "[::]:0".to_string(),
+                join_timeout_ms: 100,
+                min_replicas: 1,
+                quorum_tick_ms: 100,
+                heartbeat_timeout_ms: 5000,
+            },
+        )
         .await?;
         let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
 
@@ -591,13 +594,16 @@ mod tests {
 
     #[tokio::test]
     async fn test_get_quorum() -> Result<()> {
-        let lighthouse = Lighthouse::new(LighthouseOpt {
-            bind: "[::]:0".to_string(),
-            join_timeout_ms: 100,
-            min_replicas: 1,
-            quorum_tick_ms: 100,
-            heartbeat_timeout_ms: 5000,
-        })
+        let lighthouse = Lighthouse::new(
+            "".to_string(),
+            LighthouseOpt {
+                bind: "[::]:0".to_string(),
+                join_timeout_ms: 100,
+                min_replicas: 1,
+                quorum_tick_ms: 100,
+                heartbeat_timeout_ms: 5000,
+            },
+        )
         .await?;
         let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
 
@@ -646,13 +652,16 @@ mod tests {
 
     #[tokio::test]
     async fn test_get_quorum_heal_first_step() -> Result<()> {
-        let lighthouse = Lighthouse::new(LighthouseOpt {
-            bind: "[::]:0".to_string(),
-            join_timeout_ms: 100,
-            min_replicas: 2,
-            quorum_tick_ms: 100,
-            heartbeat_timeout_ms: 5000,
-        })
+        let lighthouse = Lighthouse::new(
+            "".to_string(),
+            LighthouseOpt {
+                bind: "[::]:0".to_string(),
+                join_timeout_ms: 100,
+                min_replicas: 2,
+                quorum_tick_ms: 100,
+                heartbeat_timeout_ms: 5000,
+            },
+        )
         .await?;
         let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
 
@@ -718,13 +727,16 @@ mod tests {
 
     #[tokio::test]
     async fn test_checkpoint_metadata() -> Result<()> {
-        let lighthouse = Lighthouse::new(LighthouseOpt {
-            bind: "[::]:0".to_string(),
-            join_timeout_ms: 100,
-            min_replicas: 1,
-            quorum_tick_ms: 100,
-            heartbeat_timeout_ms: 5000,
-        })
+        let lighthouse = Lighthouse::new(
+            "".to_string(),
+            LighthouseOpt {
+                bind: "[::]:0".to_string(),
+                join_timeout_ms: 100,
+                min_replicas: 1,
+                quorum_tick_ms: 100,
+                heartbeat_timeout_ms: 5000,
+            },
+        )
         .await?;
         let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi
index 9614d1b0..31eb3481 100644
--- a/torchft/_torchft.pyi
+++ b/torchft/_torchft.pyi
@@ -89,6 +89,7 @@ class Quorum:
 class LighthouseClient:
     addr: str
     connect_timeout: timedelta
+    room_id: Optional[str] = None
     def quorum(
         self,