Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions hyperactor/src/attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,12 @@ impl Attrs {
) {
self.values.insert(name, value);
}

/// Internal getter by key name for explicitly-set values (no
/// defaults).
pub(crate) fn get_value_by_name(&self, name: &'static str) -> Option<&dyn SerializableValue> {
self.values.get(name).map(|b| b.as_ref())
}
}

impl Clone for Attrs {
Expand Down
276 changes: 17 additions & 259 deletions hyperactor/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,24 @@
* LICENSE file in the root directory of this source tree.
*/

//! Configuration for Hyperactor.
//! Configuration keys and I/O for hyperactor.
//!
//! This module provides a centralized way to manage configuration settings for Hyperactor.
//! It uses the attrs system for type-safe, flexible configuration management that supports
//! environment variables, YAML files, and temporary modifications for tests.
//! This module declares all config keys (`declare_attrs!`) and
//! provides helpers to load/save `Attrs` (from env via `from_env`,
//! from YAML via `from_yaml`, and `to_yaml`). It also re-exports the
//! process-wide layered store under [`crate::config::global`].
//!
//! For reading/writing the process-global configuration (layered
//! resolution, test overrides), see [`crate::config::global`].

/// Global layered configuration store.
///
/// This submodule defines the process-wide configuration layers
/// (`File`, `Env`, `Runtime`, and `TestOverride`), resolution order,
/// and guard types (`ConfigLock`, `ConfigValueGuard`) used for
/// testing. Use this when you need to read or temporarily override
/// values in the global configuration state.
pub mod global;

use std::env;
use std::fs::File;
Expand Down Expand Up @@ -171,172 +184,6 @@ pub fn to_yaml<P: AsRef<Path>>(attrs: &Attrs, path: P) -> Result<(), anyhow::Err
Ok(())
}

/// Global configuration functions
///
/// This module provides global configuration access and testing utilities.
///
/// # Testing with Global Configuration
///
/// Tests can override global configuration using [`global::lock`]. This ensures that
/// such tests are serialized (and cannot clobber each other's overrides).
///
/// ```ignore rust
/// #[test]
/// fn test_my_feature() {
/// let config = hyperactor::config::global::lock();
/// let _guard = config.override_key(SOME_CONFIG_KEY, test_value);
/// // ... test logic here ...
/// }
/// ```
pub mod global {
use std::marker::PhantomData;

use super::*;
use crate::attrs::AttrValue;
use crate::attrs::Key;

/// Global configuration instance, initialized from environment variables.
static CONFIG: LazyLock<Arc<RwLock<Attrs>>> =
LazyLock::new(|| Arc::new(RwLock::new(from_env())));

/// Acquire the global configuration lock for testing.
///
/// This function returns a ConfigLock that acts as both a write lock guard (preventing
/// other tests from modifying global config concurrently) and as the only way to
/// create configuration overrides.
///
/// Example usage:
/// ```ignore rust
/// let config = hyperactor::config::global::lock();
/// let _guard = config.override_key(CONFIG_KEY, "value");
/// // ... test code using the overridden config ...
/// ```
pub fn lock() -> ConfigLock {
static MUTEX: LazyLock<std::sync::Mutex<()>> = LazyLock::new(|| std::sync::Mutex::new(()));
ConfigLock {
_guard: MUTEX.lock().unwrap(),
}
}

/// Initialize the global configuration from environment variables
pub fn init_from_env() {
let config = from_env();
let mut global_config = CONFIG.write().unwrap();
*global_config = config;
}

/// Initialize the global configuration from a YAML file
pub fn init_from_yaml<P: AsRef<Path>>(path: P) -> Result<(), anyhow::Error> {
let config = from_yaml(path)?;
let mut global_config = CONFIG.write().unwrap();
*global_config = config;
Ok(())
}

/// Get a key from the global configuration. Currently only available for Copy types.
/// `get` assumes that the key has a default value.
pub fn get<T: AttrValue + Copy>(key: Key<T>) -> T {
*CONFIG.read().unwrap().get(key).unwrap()
}

/// Get a key from the global configuration by cloning the value.
pub fn get_cloned<T: AttrValue>(key: Key<T>) -> T {
CONFIG.read().unwrap().get(key).unwrap().clone()
}

/// Get the global attrs
pub fn attrs() -> Attrs {
CONFIG.read().unwrap().clone()
}

/// Reset the global configuration to defaults (for testing only)
///
/// Note: This should be called from within with_test_lock() to ensure thread safety.
/// Available in all builds to support tests in other crates.
pub fn reset_to_defaults() {
let mut config = CONFIG.write().unwrap();
*config = Attrs::new();
}

/// A guard that holds the global configuration lock and provides override functionality.
///
/// This struct acts as both a lock guard (preventing other tests from modifying global config)
/// and as the only way to create configuration overrides. Override guards cannot outlive
/// this ConfigLock, ensuring proper synchronization.
pub struct ConfigLock {
_guard: std::sync::MutexGuard<'static, ()>,
}

impl ConfigLock {
/// Create a configuration override that will be restored when the guard is dropped.
///
/// The returned guard must not outlive this ConfigLock.
pub fn override_key<'a, T: AttrValue>(
&'a self,
key: crate::attrs::Key<T>,
value: T,
) -> ConfigValueGuard<'a, T> {
let orig = {
let mut config = CONFIG.write().unwrap();
let orig = config.remove_value(key);
config.set(key, value.clone());
orig
};

let orig_env = if let Some(env_var) = key.attrs().get(CONFIG_ENV_VAR) {
let orig = std::env::var(env_var).ok();
// SAFETY: this is used in tests
unsafe {
std::env::set_var(env_var, value.display());
}
Some((env_var.clone(), orig))
} else {
None
};

ConfigValueGuard {
key,
orig,
orig_env,
_phantom: PhantomData,
}
}
}

/// A guard that restores a single configuration value when dropped
pub struct ConfigValueGuard<'a, T: 'static> {
key: crate::attrs::Key<T>,
orig: Option<Box<dyn crate::attrs::SerializableValue>>,
orig_env: Option<(String, Option<String>)>,
// This is here so we can hold onto a 'a lifetime.
_phantom: PhantomData<&'a ()>,
}

impl<T: 'static> Drop for ConfigValueGuard<'_, T> {
fn drop(&mut self) {
let mut config = CONFIG.write().unwrap();
if let Some(orig) = self.orig.take() {
config.insert_value(self.key, orig);
} else {
config.remove_value(self.key);
}
if let Some((key, value)) = self.orig_env.take() {
if let Some(value) = value {
// SAFETY: this is used in tests
unsafe {
std::env::set_var(key, value);
}
} else {
// SAFETY: this is used in tests
unsafe {
std::env::remove_var(&key);
}
}
}
}
}
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;
Expand Down Expand Up @@ -429,29 +276,6 @@ mod tests {
unsafe { std::env::remove_var("HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS") };
}

#[test]
fn test_global_config() {
let config = global::lock();

// Reset global config to defaults to avoid interference from other tests
global::reset_to_defaults();

assert_eq!(
global::get(CODEC_MAX_FRAME_LENGTH),
CODEC_MAX_FRAME_LENGTH_DEFAULT
);
{
let _guard = config.override_key(CODEC_MAX_FRAME_LENGTH, 1024);
assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 1024);
// The configuration will be automatically restored when _guard goes out of scope
}

assert_eq!(
global::get(CODEC_MAX_FRAME_LENGTH),
CODEC_MAX_FRAME_LENGTH_DEFAULT
);
}

#[test]
fn test_defaults() {
// Test that empty config now returns defaults via get_or_default
Expand Down Expand Up @@ -523,70 +347,4 @@ mod tests {
Duration::from_secs(30)
);
}

#[test]
fn test_overrides() {
let config = global::lock();

// Reset global config to defaults to avoid interference from other tests
global::reset_to_defaults();

// Test the new lock/override API for individual config values
assert_eq!(
global::get(CODEC_MAX_FRAME_LENGTH),
CODEC_MAX_FRAME_LENGTH_DEFAULT
);
assert_eq!(
global::get(MESSAGE_DELIVERY_TIMEOUT),
Duration::from_secs(30)
);

// Test single value override
{
let _guard = config.override_key(CODEC_MAX_FRAME_LENGTH, 2048);
assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 2048);
assert_eq!(
global::get(MESSAGE_DELIVERY_TIMEOUT),
Duration::from_secs(30)
); // Unchanged
}

// Values should be restored after guard is dropped
assert_eq!(
global::get(CODEC_MAX_FRAME_LENGTH),
CODEC_MAX_FRAME_LENGTH_DEFAULT
);

// Test multiple overrides
let orig_value = std::env::var("HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT").ok();
{
let _guard1 = config.override_key(CODEC_MAX_FRAME_LENGTH, 4096);
let _guard2 = config.override_key(MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(60));

assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 4096);
assert_eq!(
global::get(MESSAGE_DELIVERY_TIMEOUT),
Duration::from_secs(60)
);
// This was overridden:
assert_eq!(
std::env::var("HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT").unwrap(),
"1m"
);
}
assert_eq!(
std::env::var("HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT").ok(),
orig_value
);

// All values should be restored
assert_eq!(
global::get(CODEC_MAX_FRAME_LENGTH),
CODEC_MAX_FRAME_LENGTH_DEFAULT
);
assert_eq!(
global::get(MESSAGE_DELIVERY_TIMEOUT),
Duration::from_secs(30)
);
}
}
Loading
Loading