diff --git a/hyperactor_mesh/src/v1/host_mesh.rs b/hyperactor_mesh/src/v1/host_mesh.rs index 554fada9c..f951df479 100644 --- a/hyperactor_mesh/src/v1/host_mesh.rs +++ b/hyperactor_mesh/src/v1/host_mesh.rs @@ -511,23 +511,11 @@ impl HostMeshRef { ))); } - let labels = self - .region - .labels() - .to_vec() - .into_iter() - .chain(per_host.labels().to_vec().into_iter()) - .collect(); - let sizes = self + let extent = self .region .extent() - .sizes() - .to_vec() - .into_iter() - .chain(per_host.sizes().to_vec().into_iter()) - .collect(); - let extent = - Extent::new(labels, sizes).map_err(|err| v1::Error::ConfigurationError(err.into()))?; + .concat(&per_host) + .map_err(|err| v1::Error::ConfigurationError(err.into()))?; let mesh_name = Name::new(name); let mut procs = Vec::new(); diff --git a/ndslice/src/view.rs b/ndslice/src/view.rs index 643d269d0..4c2555f68 100644 --- a/ndslice/src/view.rs +++ b/ndslice/src/view.rs @@ -59,6 +59,16 @@ pub enum ExtentError { /// Number of dimension sizes provided. num_sizes: usize, }, + + /// An overlapping label was found. + /// + /// This occurs when attempting to combine extents that + /// share one or more dimension labels, which is not allowed. + #[error("overlapping label found: {label}")] + OverlappingLabel { + /// The label that appears in both extents. + label: String, + }, } /// `Extent` defines the logical shape of a multidimensional space by @@ -307,6 +317,30 @@ impl Extent { pub fn points(&self) -> ExtentPointsIterator<'_> { ExtentPointsIterator::new(self) } + + /// Append the dimensions of `other` to this extent, preserving order. + /// + /// Duplicate labels are not allowed: if any label in `other` already appears + /// in `self`, this returns `ExtentError::OverlappingLabel`. + /// + /// This operation is not commutative: `a.concat(&b)` may differ from + /// `b.concat(&a)`. + pub fn concat(&self, other: &Extent) -> Result { + use std::collections::HashSet; + // Check for any overlapping labels in linear time using hash set + let lhs: HashSet<&str> = self.labels().iter().map(|s| s.as_str()).collect(); + if let Some(dup) = other.labels().iter().find(|l| lhs.contains(l.as_str())) { + return Err(ExtentError::OverlappingLabel { label: dup.clone() }); + } + // Combine labels and sizes from both extents with pre-allocated memory + let mut labels = self.labels().to_vec(); + let mut sizes = self.sizes().to_vec(); + labels.reserve(other.labels().len()); + sizes.reserve(other.sizes().len()); + labels.extend(other.labels().iter().cloned()); + sizes.extend(other.sizes().iter().copied()); + Extent::new(labels, sizes) + } } /// Label formatting utilities shared across `Extent`, `Region`, and @@ -2111,6 +2145,99 @@ mod test { assert!(it.next().is_none()); // fused } + #[test] + fn test_extent_concat() { + // Test basic concatenation of two extents with preserved order of labels + let extent1 = extent!(x = 2, y = 3); + let extent2 = extent!(z = 4, w = 5); + + let result = extent1.concat(&extent2).unwrap(); + assert_eq!(result.labels(), &["x", "y", "z", "w"]); + assert_eq!(result.sizes(), &[2, 3, 4, 5]); + assert_eq!(result.num_ranks(), 2 * 3 * 4 * 5); + + // Test concatenating with empty extent + let empty = extent!(); + let result = extent1.concat(&empty).unwrap(); + assert_eq!(result.labels(), &["x", "y"]); + assert_eq!(result.sizes(), &[2, 3]); + + let result = empty.concat(&extent1).unwrap(); + assert_eq!(result.labels(), &["x", "y"]); + assert_eq!(result.sizes(), &[2, 3]); + + // Test concatenating two empty extents + let result = empty.concat(&empty).unwrap(); + assert_eq!(result.labels(), &[] as &[String]); + assert_eq!(result.sizes(), &[] as &[usize]); + assert_eq!(result.num_ranks(), 1); // 0-dimensional extent has 1 rank + + // Test self-concatenation (overlapping labels should cause error) + let result = extent1.concat(&extent1); + assert!( + result.is_err(), + "Self-concatenation should error due to overlapping labels" + ); + match result.unwrap_err() { + ExtentError::OverlappingLabel { label } => { + assert!(label == "x"); // Overlapping label should be "x" + } + other => panic!("Expected OverlappingLabel error, got {:?}", other), + } + + // Test concatenation creates valid points + let result = extent1.concat(&extent2).unwrap(); + let point = result.point(vec![1, 2, 3, 4]).unwrap(); + assert_eq!(point.coords(), vec![1, 2, 3, 4]); + assert_eq!(point.extent(), &result); + + // Test error case: overlapping labels with same size (should error) + let extent_a = extent!(x = 2, y = 3); + let extent_b = extent!(y = 3, z = 4); // y overlaps with same size + let result = extent_a.concat(&extent_b); + assert!( + result.is_err(), + "Should error on overlapping labels even with same size" + ); + match result.unwrap_err() { + ExtentError::OverlappingLabel { label } => { + assert_eq!(label, "y"); // the overlapping label + } + other => panic!("Expected OverlappingLabel error, got {:?}", other), + } + + // Test that Extent::concat preserves order and is not commutative + let extent_x = extent!(x = 2, y = 3); + let extent_y = extent!(z = 4); + assert_eq!( + extent_x.concat(&extent_y).unwrap().labels(), + &["x", "y", "z"] + ); + assert_eq!( + extent_y.concat(&extent_x).unwrap().labels(), + &["z", "x", "y"] + ); + + // Test associativity: (a ⊕ b) ⊕ c == a ⊕ (b ⊕ c) for disjoint labels + let extent_m = extent!(x = 2); + let extent_n = extent!(y = 3); + let extent_o = extent!(z = 4); + + let left_assoc = extent_m + .concat(&extent_n) + .unwrap() + .concat(&extent_o) + .unwrap(); + let right_assoc = extent_m + .concat(&extent_n.concat(&extent_o).unwrap()) + .unwrap(); + + assert_eq!(left_assoc, right_assoc); + assert_eq!(left_assoc.labels(), &["x", "y", "z"]); + assert_eq!(left_assoc.sizes(), &[2, 3, 4]); + assert_eq!(left_assoc.num_ranks(), 2 * 3 * 4); + } + #[test] fn extent_unity_equiv_to_0d() { let e = Extent::unity();