diff --git a/hyperactor/src/config.rs b/hyperactor/src/config.rs index f8ac31f7f..0aa44a609 100644 --- a/hyperactor/src/config.rs +++ b/hyperactor/src/config.rs @@ -240,6 +240,12 @@ pub mod global { CONFIG.read().unwrap().get(key).unwrap().clone() } + /// Get a key from the global configuration by cloning the value, + /// if it exists. Returns None if the key is not present. + pub fn try_get_cloned(key: Key) -> Option { + CONFIG.read().unwrap().get(key).cloned() + } + /// Get the global attrs pub fn attrs() -> Attrs { CONFIG.read().unwrap().clone() @@ -254,6 +260,12 @@ pub mod global { *config = Attrs::new(); } + /// Set a key in the global configuration. + pub fn set(key: Key, value: T) { + let mut config = CONFIG.write().unwrap(); + config.insert_value(key, Box::new(value)); + } + /// A guard that holds the global configuration lock and provides override functionality. /// /// This struct acts as both a lock guard (preventing other tests from modifying global config) diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index 2da0b1b31..8953813b6 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -86,7 +86,7 @@ use std::sync::RwLock; declare_attrs! { /// Default transport type to use across the application. @meta(CONFIG_ENV_VAR = "HYPERACTOR_MESH_DEFAULT_TRANSPORT".to_string()) - attr DEFAULT_TRANSPORT: ChannelTransport = ChannelTransport::Unix; + pub attr DEFAULT_TRANSPORT: ChannelTransport = ChannelTransport::Unix; } /// Get the default transport type to use across the application. diff --git a/monarch_hyperactor/src/channel.rs b/monarch_hyperactor/src/channel.rs index 7f152c6e9..007a8fd32 100644 --- a/monarch_hyperactor/src/channel.rs +++ b/monarch_hyperactor/src/channel.rs @@ -12,6 +12,7 @@ use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; use hyperactor::channel::TlsMode; use pyo3::exceptions::PyRuntimeError; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; /// Python binding for [`hyperactor::channel::ChannelTransport`] @@ -30,6 +31,33 @@ pub enum PyChannelTransport { // Sim(/*transport:*/ ChannelTransport), TODO kiuk@ add support } +#[pymethods] +impl PyChannelTransport { + fn get(&self) -> Self { + self.clone() + } +} + +impl TryFrom for PyChannelTransport { + type Error = PyErr; + + fn try_from(transport: ChannelTransport) -> PyResult { + match transport { + ChannelTransport::Tcp => Ok(PyChannelTransport::Tcp), + ChannelTransport::MetaTls(TlsMode::Hostname) => { + Ok(PyChannelTransport::MetaTlsWithHostname) + } + ChannelTransport::MetaTls(TlsMode::IpV6) => Ok(PyChannelTransport::MetaTlsWithIpV6), + ChannelTransport::Local => Ok(PyChannelTransport::Local), + ChannelTransport::Unix => Ok(PyChannelTransport::Unix), + _ => Err(PyValueError::new_err(format!( + "unsupported transport: {}", + transport + ))), + } + } +} + #[pyclass( name = "ChannelAddr", module = "monarch._rust_bindings.monarch_hyperactor.channel" diff --git a/monarch_hyperactor/src/config.rs b/monarch_hyperactor/src/config.rs index cb3edd8a8..c7c8dafbf 100644 --- a/monarch_hyperactor/src/config.rs +++ b/monarch_hyperactor/src/config.rs @@ -12,8 +12,11 @@ //! the base hyperactor configuration system. use hyperactor::attrs::declare_attrs; +use hyperactor_mesh::proc_mesh::DEFAULT_TRANSPORT; use pyo3::prelude::*; +use crate::channel::PyChannelTransport; + // Declare monarch-specific configuration keys declare_attrs! { /// Use a single asyncio runtime for all Python actors, rather than one per actor @@ -30,6 +33,99 @@ pub fn reload_config_from_env() -> PyResult<()> { Ok(()) } +struct ConfigKeyInfo { + register: fn(&Bound<'_, PyModule>) -> PyResult<()>, +} + +inventory::collect!(ConfigKeyInfo); + +macro_rules! register_config_key { + ($id:ident) => { + hyperactor::paste! { + hyperactor::submit! { + ConfigKeyInfo { + register: |module| { + module.add_class::<[]>() + } + } + } + } + }; +} + +fn _on_set() -> PyResult<()> { + Ok(()) +} + +/// Define python bindings to make +macro_rules! py_configurable { + (py, $id:ident, $py_name:literal, $py_ty:ty) => { + py_configurable!(py, $id, $py_name, $py_ty, _on_set); + }; + ($id:ident, $py_name:literal, $ty:ty) => { + py_configurable!($id, $py_name, $ty, _on_set); + }; + (py, $id:ident, $py_name:literal, $py_ty:ty, $on_set:ident) => { + hyperactor::paste! { + #[pyclass(name = $py_name, module = "monarch._rust_bindings.monarch_hyperactor.config", frozen)] + #[allow(non_camel_case_types)] + #[derive(Clone)] + struct []; + + #[pymethods] + impl [] { + #[staticmethod] + fn get() -> PyResult> { + hyperactor::config::global::try_get_cloned($id) + .map(|val| val.try_into()) + .transpose() + } + + #[staticmethod] + fn set(val: &$py_ty) -> PyResult<()> { + hyperactor::config::global::set($id, val.clone().into()); + $on_set() + } + } + + register_config_key!($id); + } + }; + ($id:ident, $py_name:literal, $ty:ty, $on_set:expr) => { + hyperactor::paste! { + #[pyclass(name = $py_name, module = "monarch._rust_bindings.monarch_hyperactor.config", frozen)] + #[allow(non_camel_case_types)] + #[derive(Clone)] + struct []; + + #[pymethods] + impl [] { + #[staticmethod] + fn get() -> Option<$ty> { + hyperactor::config::global::try_get_cloned($id) + } + + #[staticmethod] + fn set(val: $ty) -> PyResult<()> { + hyperactor::config::global::set($id, val); + $on_set() + } + } + + register_config_key!($id); + } + }; +} + +// TODO(slurye): Add a callback to re-initialize the root client +// when default transport changes. +py_configurable!( + py, + DEFAULT_TRANSPORT, + "DefaultTransport", + PyChannelTransport +); + /// Register Python bindings for the config module pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { let reload = wrap_pyfunction!(reload_config_from_env, module)?; @@ -38,5 +134,10 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { "monarch._rust_bindings.monarch_hyperactor.config", )?; module.add_function(reload)?; + + for key in inventory::iter::() { + (key.register)(module)?; + } + Ok(()) } diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi index bc8eb6737..bfec3fa53 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi @@ -10,7 +10,8 @@ from enum import Enum class ChannelTransport(Enum): Tcp = "tcp" - MetaTls = "metatls" + MetaTlsWithHostname = "metatls(hostname)" + MetaTlsWithIpV6 = "metatls(ipv6)" Local = "local" Unix = "unix" # Sim # TODO add support diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi index 068ae3906..05b453dba 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi @@ -10,6 +10,10 @@ Type hints for the monarch_hyperactor.config Rust bindings. """ +from typing import Generic, TypeVar + +from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport + def reload_config_from_env() -> None: """ Reload configuration from environment variables. @@ -18,3 +22,18 @@ def reload_config_from_env() -> None: the global configuration. """ ... + +T = TypeVar("T") + +# ConfigKey isn't actually a class that exists, +# and the rust configuration keys like DefaultTransport +# don't share a common subclass. But this is nice for +# type-checking and not having to stub out get and set +# methods for every config key. +class ConfigKey(Generic[T]): + @staticmethod + def get() -> T: ... + @staticmethod + def set(val: T) -> None: ... + +class DefaultTransport(ConfigKey[ChannelTransport]): ... diff --git a/python/tests/test_config.py b/python/tests/test_config.py new file mode 100644 index 000000000..d2f6f245b --- /dev/null +++ b/python/tests/test_config.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport +from monarch._rust_bindings.monarch_hyperactor.config import DefaultTransport + + +def test_get_set_transport() -> None: + DefaultTransport.set(ChannelTransport.Tcp) + assert DefaultTransport.get() == ChannelTransport.Tcp + DefaultTransport.set(ChannelTransport.MetaTlsWithHostname) + assert DefaultTransport.get() == ChannelTransport.MetaTlsWithHostname + DefaultTransport.set(ChannelTransport.MetaTlsWithIpV6) + assert DefaultTransport.get() == ChannelTransport.MetaTlsWithIpV6