JAX prefers immutable objects but neither Python nor JAX provide an immutable
dictionary. 😢
This repository defines a light-weight immutable map
(lower-level than a dict) that JAX understands as a PyTree. 🎉 🕶️
pip install xmmutablemap
using uv
uv add xmmutablemap
from source, using pip
pip install git+https://github.com/GalacticDynamics/xmmutablemap.git
building from source
cd /path/to/parent
git clone https://github.com/GalacticDynamics/xmmutablemap.git
cd xmmutablemap
pip install -e . # editable mode
xmutablemap
provides the class ImmutableMap
, which is a full implementation
of
Python's Mapping
ABC.
If you've used a dict
then you already know how to use ImmutableMap
! The
things ImmutableMap
adds is 1) immutability (and related benefits like
hashability) and 2) compatibility with JAX
.
from xmmutablemap import ImmutableMap
print(ImmutableMap(a=1, b=2, c=3))
# ImmutableMap({'a': 1, 'b': 2, 'c': 3})
print(ImmutableMap({"a": 1, "b": 2.0, "c": "3"}))
# ImmutableMap({'a': 1, 'b': 2.0, 'c': '3'})
One of the key benefits of ImmutableMap
is its compatibility with JAX. Since
it's immutable and hashable, it can be used in places where JAX would normally
complain about mutable objects like regular dictionaries.
Here's an example showing how ImmutableMap
can be used as a default value in a
dataclass, which is particularly useful with JAX:
import functools
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from xmmutablemap import ImmutableMap
@functools.partial(
jax.tree_util.register_dataclass, data_fields=["params"], meta_fields=["batch_size"]
)
@dataclass(frozen=True)
class Config:
"""Configuration with immutable default parameters."""
# This works! ImmutableMap is immutable and hashable
params: ImmutableMap[str, float] = ImmutableMap(
learning_rate=0.001, momentum=0.9, weight_decay=1e-4
)
batch_size: int = 32
# JAX can safely transform functions using this dataclass
@jax.jit
def train_step(config: Config, data: jnp.ndarray) -> jnp.ndarray:
"""Example training step that uses config parameters."""
lr = config.params["learning_rate"]
return data * lr
# This works perfectly
config = Config()
data = jnp.array([1.0, 2.0, 3.0])
result = train_step(config, data)
print(f"Result: {result}")
# Result: [0.001 0.002 0.003]
- Immutability: Once created,
ImmutableMap
cannot be modified, preventing accidental mutations that could break JAX's functional programming model - Hashability: JAX can safely cache and memoize functions that use
ImmutableMap
instances - PyTree Support:
ImmutableMap
is registered as a JAX PyTree, so it works seamlessly with JAX transformations likejit
,grad
,vmap
, etc. - Safe Defaults: Can be used as default values in dataclasses without the typical pitfalls of mutable defaults
We welcome contributions!