Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions hyperactor/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,12 @@ pub mod global {
CONFIG.read().unwrap().get(key).unwrap().clone()
}

/// Get a key from the global configuration by cloning the value,
/// if it exists. Returns None if the key is not present.
pub fn try_get_cloned<T: AttrValue>(key: Key<T>) -> Option<T> {
CONFIG.read().unwrap().get(key).cloned()
}

/// Get the global attrs
pub fn attrs() -> Attrs {
CONFIG.read().unwrap().clone()
Expand All @@ -254,6 +260,12 @@ pub mod global {
*config = Attrs::new();
}

/// Set a key in the global configuration.
pub fn set<T: AttrValue>(key: Key<T>, value: T) {
let mut config = CONFIG.write().unwrap();
config.insert_value(key, Box::new(value));
}

/// A guard that holds the global configuration lock and provides override functionality.
///
/// This struct acts as both a lock guard (preventing other tests from modifying global config)
Expand Down
2 changes: 1 addition & 1 deletion hyperactor_mesh/src/proc_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ use std::sync::RwLock;
declare_attrs! {
/// Default transport type to use across the application.
@meta(CONFIG_ENV_VAR = "HYPERACTOR_MESH_DEFAULT_TRANSPORT".to_string())
attr DEFAULT_TRANSPORT: ChannelTransport = ChannelTransport::Unix;
pub attr DEFAULT_TRANSPORT: ChannelTransport = ChannelTransport::Unix;
}

/// Get the default transport type to use across the application.
Expand Down
28 changes: 28 additions & 0 deletions monarch_hyperactor/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::TlsMode;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Python binding for [`hyperactor::channel::ChannelTransport`]
Expand All @@ -30,6 +31,33 @@ pub enum PyChannelTransport {
// Sim(/*transport:*/ ChannelTransport), TODO kiuk@ add support
}

#[pymethods]
impl PyChannelTransport {
fn get(&self) -> Self {
self.clone()
}
}

impl TryFrom<ChannelTransport> for PyChannelTransport {
type Error = PyErr;

fn try_from(transport: ChannelTransport) -> PyResult<Self> {
match transport {
ChannelTransport::Tcp => Ok(PyChannelTransport::Tcp),
ChannelTransport::MetaTls(TlsMode::Hostname) => {
Ok(PyChannelTransport::MetaTlsWithHostname)
}
ChannelTransport::MetaTls(TlsMode::IpV6) => Ok(PyChannelTransport::MetaTlsWithIpV6),
ChannelTransport::Local => Ok(PyChannelTransport::Local),
ChannelTransport::Unix => Ok(PyChannelTransport::Unix),
_ => Err(PyValueError::new_err(format!(
"unsupported transport: {}",
transport
))),
}
}
}

#[pyclass(
name = "ChannelAddr",
module = "monarch._rust_bindings.monarch_hyperactor.channel"
Expand Down
101 changes: 101 additions & 0 deletions monarch_hyperactor/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
//! the base hyperactor configuration system.

use hyperactor::attrs::declare_attrs;
use hyperactor_mesh::proc_mesh::DEFAULT_TRANSPORT;
use pyo3::prelude::*;

use crate::channel::PyChannelTransport;

// Declare monarch-specific configuration keys
declare_attrs! {
/// Use a single asyncio runtime for all Python actors, rather than one per actor
Expand All @@ -30,6 +33,99 @@ pub fn reload_config_from_env() -> PyResult<()> {
Ok(())
}

struct ConfigKeyInfo {
register: fn(&Bound<'_, PyModule>) -> PyResult<()>,
}

inventory::collect!(ConfigKeyInfo);

macro_rules! register_config_key {
($id:ident) => {
hyperactor::paste! {
hyperactor::submit! {
ConfigKeyInfo {
register: |module| {
module.add_class::<[<PY_ $id>]>()
}
}
}
}
};
}

fn _on_set() -> PyResult<()> {
Ok(())
}

/// Define python bindings to make
macro_rules! py_configurable {
(py, $id:ident, $py_name:literal, $py_ty:ty) => {
py_configurable!(py, $id, $py_name, $py_ty, _on_set);
};
($id:ident, $py_name:literal, $ty:ty) => {
py_configurable!($id, $py_name, $ty, _on_set);
};
(py, $id:ident, $py_name:literal, $py_ty:ty, $on_set:ident) => {
hyperactor::paste! {
#[pyclass(name = $py_name, module = "monarch._rust_bindings.monarch_hyperactor.config", frozen)]
#[allow(non_camel_case_types)]
#[derive(Clone)]
struct [<PY_ $id>];

#[pymethods]
impl [<PY_ $id>] {
#[staticmethod]
fn get() -> PyResult<Option<$py_ty>> {
hyperactor::config::global::try_get_cloned($id)
.map(|val| val.try_into())
.transpose()
}

#[staticmethod]
fn set(val: &$py_ty) -> PyResult<()> {
hyperactor::config::global::set($id, val.clone().into());
$on_set()
}
}

register_config_key!($id);
}
};
($id:ident, $py_name:literal, $ty:ty, $on_set:expr) => {
hyperactor::paste! {
#[pyclass(name = $py_name, module = "monarch._rust_bindings.monarch_hyperactor.config", frozen)]
#[allow(non_camel_case_types)]
#[derive(Clone)]
struct [<PY_ $id>];

#[pymethods]
impl [<PY_ $id>] {
#[staticmethod]
fn get() -> Option<$ty> {
hyperactor::config::global::try_get_cloned($id)
}

#[staticmethod]
fn set(val: $ty) -> PyResult<()> {
hyperactor::config::global::set($id, val);
$on_set()
}
}

register_config_key!($id);
}
};
}

// TODO(slurye): Add a callback to re-initialize the root client
// when default transport changes.
py_configurable!(
py,
DEFAULT_TRANSPORT,
"DefaultTransport",
PyChannelTransport
);

/// Register Python bindings for the config module
pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
let reload = wrap_pyfunction!(reload_config_from_env, module)?;
Expand All @@ -38,5 +134,10 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
"monarch._rust_bindings.monarch_hyperactor.config",
)?;
module.add_function(reload)?;

for key in inventory::iter::<ConfigKeyInfo>() {
(key.register)(module)?;
}

Ok(())
}
3 changes: 2 additions & 1 deletion python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ from enum import Enum

class ChannelTransport(Enum):
Tcp = "tcp"
MetaTls = "metatls"
MetaTlsWithHostname = "metatls(hostname)"
MetaTlsWithIpV6 = "metatls(ipv6)"
Local = "local"
Unix = "unix"
# Sim # TODO add support
Expand Down
19 changes: 19 additions & 0 deletions python/monarch/_rust_bindings/monarch_hyperactor/config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
Type hints for the monarch_hyperactor.config Rust bindings.
"""

from typing import Generic, TypeVar

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport

def reload_config_from_env() -> None:
"""
Reload configuration from environment variables.
Expand All @@ -18,3 +22,18 @@ def reload_config_from_env() -> None:
the global configuration.
"""
...

T = TypeVar("T")

# ConfigKey isn't actually a class that exists,
# and the rust configuration keys like DefaultTransport
# don't share a common subclass. But this is nice for
# type-checking and not having to stub out get and set
# methods for every config key.
class ConfigKey(Generic[T]):
@staticmethod
def get() -> T: ...
@staticmethod
def set(val: T) -> None: ...

class DefaultTransport(ConfigKey[ChannelTransport]): ...
19 changes: 19 additions & 0 deletions python/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import DefaultTransport


def test_get_set_transport() -> None:
DefaultTransport.set(ChannelTransport.Tcp)
assert DefaultTransport.get() == ChannelTransport.Tcp
DefaultTransport.set(ChannelTransport.MetaTlsWithHostname)
assert DefaultTransport.get() == ChannelTransport.MetaTlsWithHostname
DefaultTransport.set(ChannelTransport.MetaTlsWithIpV6)
assert DefaultTransport.get() == ChannelTransport.MetaTlsWithIpV6
Loading