
GalacticDynamics/xmmutablemap


xmmutablemap

JAX-compatible Immutable Mapping

JAX prefers immutable objects, but neither Python nor JAX provides an immutable dictionary. 😢
This repository defines a lightweight immutable map (lower-level than a dict) that JAX understands as a PyTree. 🎉 🕶️

Installation


Using pip:
pip install xmmutablemap

Using uv:
uv add xmmutablemap

From source, using pip:
pip install git+https://github.com/GalacticDynamics/xmmutablemap.git

Building from source:
cd /path/to/parent
git clone https://github.com/GalacticDynamics/xmmutablemap.git
cd xmmutablemap
pip install -e .  # editable mode

Documentation

xmmutablemap provides the class ImmutableMap, a full implementation of Python's Mapping ABC. If you've used a dict, you already know how to use ImmutableMap! ImmutableMap adds two things: 1) immutability (and related benefits such as hashability) and 2) compatibility with JAX.

from xmmutablemap import ImmutableMap

print(ImmutableMap(a=1, b=2, c=3))
# ImmutableMap({'a': 1, 'b': 2, 'c': 3})

print(ImmutableMap({"a": 1, "b": 2.0, "c": "3"}))
# ImmutableMap({'a': 1, 'b': 2.0, 'c': '3'})
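To make the Mapping-ABC point concrete, here is a minimal sketch of an immutable, hashable mapping built on collections.abc.Mapping. FrozenMapSketch is a hypothetical name used only for illustration; it is not how xmmutablemap implements ImmutableMap, and it leaves out the JAX PyTree registration entirely.

```python
from collections.abc import Iterator, Mapping
from typing import Any


class FrozenMapSketch(Mapping):
    """Hypothetical immutable mapping for illustration (not xmmutablemap's code)."""

    __slots__ = ("_data",)

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Copy the input so later changes to the caller's dict cannot leak in.
        # object.__setattr__ bypasses the __setattr__ guard defined below.
        object.__setattr__(self, "_data", dict(*args, **kwargs))

    # Mapping only requires these three methods; it derives keys(), items(),
    # values(), get(), __contains__ and __eq__ from them.
    def __getitem__(self, key: Any) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator[Any]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    # Mapping defines __eq__ without __hash__, which leaves __hash__ set to None,
    # so hashing has to be opted into explicitly. Assumes all values are hashable.
    def __hash__(self) -> int:
        return hash(frozenset(self._data.items()))

    def __setattr__(self, name: str, value: Any) -> None:
        raise AttributeError(f"{type(self).__name__} is immutable")

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._data!r})"


m = FrozenMapSketch(a=1, b=2, c=3)
print(m)  # FrozenMapSketch({'a': 1, 'b': 2, 'c': 3})
print(m["b"], len(m), hash(m) == hash(FrozenMapSketch(c=3, b=2, a=1)))  # 2 3 True
# There is no __setitem__, so `m["d"] = 4` raises TypeError.
```

Because Mapping fills in the read-only dict methods from the three abstract ones, a complete Mapping implementation already behaves like a read-only dict; immutability and hashability are the only extras.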

JAX Integration

One of the key benefits of ImmutableMap is its compatibility with JAX. Because it's immutable and hashable, it can be used in places where a regular dict is rejected for being mutable and unhashable, such as a default value in a dataclass or a static argument to jax.jit.
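The hashability requirement can be seen directly with jax.jit: arguments marked as static must be hashable, so a plain dict is rejected while a hashable mapping is accepted. HashableMap below is a hypothetical stand-in for ImmutableMap (it is hashable but, unlike ImmutableMap, not guarded against mutation), and rather than assume which exception JAX raises for an unhashable static argument, the sketch catches both TypeError and ValueError.

```python
from collections.abc import Iterator, Mapping
from typing import Any

import jax
import jax.numpy as jnp


class HashableMap(Mapping):
    """Hypothetical hashable mapping standing in for ImmutableMap."""

    def __init__(self, **kwargs: Any) -> None:
        self._data = dict(kwargs)

    def __getitem__(self, key: str) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __hash__(self) -> int:
        # Assumes every value is hashable (true for the floats used here).
        return hash(frozenset(self._data.items()))


def scale(x: jax.Array, opts: Mapping[str, float]) -> jax.Array:
    # opts is static, so opts["factor"] is a plain Python float at trace time.
    return x * opts["factor"]


scale_jit = jax.jit(scale, static_argnums=1)
x = jnp.array([1.0, 2.0, 3.0])

print(scale_jit(x, HashableMap(factor=2.0)))  # accepted: the static argument hashes

try:
    scale_jit(x, {"factor": 2.0})  # rejected: dict is unhashable
except (TypeError, ValueError) as err:
    print(f"dict rejected with {type(err).__name__}")
```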

Using ImmutableMap as a Default in JAX Dataclasses

Here's an example showing how ImmutableMap can be used as a default value in a dataclass, which is particularly useful with JAX:

import functools
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from xmmutablemap import ImmutableMap


@functools.partial(
    jax.tree_util.register_dataclass, data_fields=["params"], meta_fields=["batch_size"]
)
@dataclass(frozen=True)
class Config:
    """Configuration with immutable default parameters."""

    # This works! ImmutableMap is immutable and hashable
    params: ImmutableMap[str, float] = ImmutableMap(
        learning_rate=0.001, momentum=0.9, weight_decay=1e-4
    )
    batch_size: int = 32


# JAX can safely transform functions using this dataclass
@jax.jit
def train_step(config: Config, data: jnp.ndarray) -> jnp.ndarray:
    """Example training step that uses config parameters."""
    lr = config.params["learning_rate"]
    return data * lr


# This works perfectly
config = Config()
data = jnp.array([1.0, 2.0, 3.0])
result = train_step(config, data)
print(f"Result: {result}")
# Result: [0.001 0.002 0.003]
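In the example above, data_fields and meta_fields decide how register_dataclass splits Config when JAX flattens it: data fields become leaves that jit traces, while meta fields are stored in the tree structure (treedef) and treated as static, so changing batch_size would cause a retrace. The hypothetical Hyper class below swaps the map for a plain float so the split is easy to inspect; it is an illustration, not part of xmmutablemap.

```python
import functools
from dataclasses import dataclass

import jax


@functools.partial(
    jax.tree_util.register_dataclass, data_fields=["lr"], meta_fields=["batch_size"]
)
@dataclass(frozen=True)
class Hyper:
    """Hypothetical stand-in for Config with a plain float instead of a map."""

    lr: float = 0.001
    batch_size: int = 32


# Only the data field becomes a leaf; batch_size lives in the treedef.
leaves, treedef = jax.tree_util.tree_flatten(Hyper())
print(leaves)  # [0.001]

# Rebuilding with new leaves keeps the metadata from the treedef unchanged.
rebuilt = jax.tree_util.tree_unflatten(treedef, [0.01])
print(rebuilt)  # Hyper(lr=0.01, batch_size=32)
```

In Config, the data field holds an ImmutableMap rather than a float, so its leaves are the map's values (learning_rate, momentum, weight_decay).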

Key Benefits for JAX

  • Immutability: Once created, ImmutableMap cannot be modified, preventing accidental mutations that could break JAX's functional programming model
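  • Hashability: ImmutableMap instances can be hashed, so they can be used where JAX requires hashable values, such as static arguments to jax.jit, which JAX uses to look up previously compiled functions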
  • PyTree Support: ImmutableMap is registered as a JAX PyTree, so it works seamlessly with JAX transformations like jit, grad, vmap, etc.
  • Safe Defaults: Can be used as default values in dataclasses without the typical pitfalls of mutable defaults
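The PyTree Support point means JAX can flatten a map into leaves and rebuild it after a transformation. A minimal sketch of that mechanism, using jax.tree_util.register_pytree_node_class on a hypothetical PyTreeMapSketch class (not xmmutablemap's actual registration), is shown below: values become children that JAX traces, and keys ride along as static auxiliary data.

```python
from collections.abc import Iterator, Mapping
from typing import Any

import jax


@jax.tree_util.register_pytree_node_class
class PyTreeMapSketch(Mapping):
    """Hypothetical read-only mapping registered as a PyTree (illustration only)."""

    def __init__(self, data: Mapping[str, Any]) -> None:
        self._data = dict(data)

    def __getitem__(self, key: str) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __repr__(self) -> str:
        return f"PyTreeMapSketch({self._data!r})"

    def tree_flatten(self) -> tuple[tuple[Any, ...], tuple[str, ...]]:
        # Children (values) are traced by JAX; aux data (keys) stays static.
        keys = tuple(self._data)
        return tuple(self._data[k] for k in keys), keys

    @classmethod
    def tree_unflatten(cls, keys: tuple[str, ...], children: Any) -> "PyTreeMapSketch":
        return cls(dict(zip(keys, children)))


def loss(params: PyTreeMapSketch) -> jax.Array:
    return params["w"] ** 2 + 3.0 * params["b"]


params = PyTreeMapSketch({"w": 2.0, "b": 1.0})
grads = jax.grad(loss)(params)
print(grads)  # a PyTreeMapSketch holding d(loss)/dw = 4.0 and d(loss)/db = 3.0
```

Keeping keys in insertion order is a choice made for this sketch: two maps with the same items inserted in a different order flatten to different tree structures, which an implementation could avoid by sorting the keys.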

Development


We welcome contributions!
