A comprehensive PyTorch-based data loading and preprocessing library for CellMap biological imaging datasets, designed for efficient machine learning training on large-scale 2D/3D volumetric data.
CellMap-Data is a specialized data loading utility that bridges the gap between large biological imaging datasets and machine learning frameworks. It provides efficient, memory-optimized data loading for training deep learning models on cell microscopy data, with support for multi-class segmentation, spatial transformations, and advanced augmentation techniques.
- Biological Data Optimized: Native support for multiscale biological imaging formats (OME-NGFF/Zarr)
- High-Performance Loading: Efficient data streaming with TensorStore backend and optimized PyTorch integration
- Flexible Target Construction: Support for multi-class segmentation with mutually exclusive class relationships
- Advanced Augmentations: Comprehensive spatial and value transformations for robust model training
- Smart Sampling: Weighted sampling strategies and validation set management
- Scalable Architecture: Memory-efficient handling of datasets larger than available RAM
- Production Ready: Thread-safe, multiprocess-compatible with extensive test coverage
pip install cellmap-data
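To verify the installation (assuming the package exposes a `__version__` attribute, as most PyPI packages do):

import cellmap_data
print(cellmap_data.__version__)  # __version__ is an assumption; the import alone confirms the install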
CellMap-Data leverages several powerful libraries:
- PyTorch: Neural network training and tensor operations
- TensorStore: High-performance array storage and retrieval
- Xarray: Labeled multi-dimensional arrays with metadata
- Pydantic: Data validation and settings management
- Zarr: Chunked, compressed array storage
from cellmap_data import CellMapDataset
# Define input and target array specifications
input_arrays = {
    "raw": {
        "shape": (64, 64, 64),  # Training patch size
        "scale": (8, 8, 8),     # Voxel resolution in nm
    }
}

target_arrays = {
    "segmentation": {
        "shape": (64, 64, 64),
        "scale": (8, 8, 8),
    }
}

# Create dataset
dataset = CellMapDataset(
    raw_path="/path/to/raw/data.zarr",
    target_path="/path/to/labels/data.zarr",
    classes=["mitochondria", "endoplasmic_reticulum", "nucleus"],
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    is_train=True,
)
import torch
import torchvision.transforms.v2 as T

from cellmap_data import CellMapDataLoader
from cellmap_data.transforms import Normalize, RandomContrast, GaussianNoise, Binarize
# Define spatial transformations
spatial_transforms = {
    "mirror": {"axes": {"x": 0.5, "y": 0.5, "z": 0.2}},
    "rotate": {"axes": {"z": [-30, 30]}},
    "transpose": {"axes": ["x", "y"]},
}

# Define value transformations
raw_value_transforms = T.Compose([
    Normalize(scale=1/255),      # Normalize to [0,1]
    GaussianNoise(std=0.05),     # Add noise for augmentation
    RandomContrast((0.8, 1.2)),  # Vary contrast
])

target_value_transforms = T.Compose([
    Binarize(threshold=0.5),   # Convert to binary masks
    T.ToDtype(torch.float32),  # Ensure correct dtype
])
# Create dataset with transforms
dataset = CellMapDataset(
    raw_path="/path/to/raw/data.zarr",
    target_path="/path/to/labels/data.zarr",
    classes=["mitochondria", "endoplasmic_reticulum", "nucleus"],
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    spatial_transforms=spatial_transforms,
    raw_value_transforms=raw_value_transforms,
    target_value_transforms=target_value_transforms,
    is_train=True,
)
# Configure data loader
loader = CellMapDataLoader(
    dataset,
    batch_size=4,
    num_workers=8,
    weighted_sampler=True,  # Balance classes automatically
    is_train=True,
)
# Training loop
for batch in loader:
    inputs = batch["raw"]            # Shape: [batch, channels, z, y, x]
    targets = batch["segmentation"]  # Multi-class targets

    # Your training code here
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
from cellmap_data import CellMapDataSplit
# Define datasets from CSV or dictionary
datasplit = CellMapDataSplit(
    csv_path="path/to/datasplit.csv",
    classes=["mitochondria", "er", "nucleus"],
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    spatial_transforms={
        "mirror": {"axes": {"x": 0.5, "y": 0.5}},
        "rotate": {"axes": {"z": [-180, 180]}},
        "transpose": {"axes": ["x", "y"]},
    },
)
# Access combined datasets
train_loader = CellMapDataLoader(
    datasplit.train_datasets_combined,
    batch_size=8,
    weighted_sampler=True,
)

val_loader = CellMapDataLoader(
    datasplit.validation_datasets_combined,
    batch_size=16,
    is_train=False,
)
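A matching validation pass uses standard PyTorch idioms. This sketch assumes a trained model and the batch keys from the quick-start example:

import torch

model.eval()
with torch.no_grad():
    for batch in val_loader:
        preds = model(batch["raw"])
        # accumulate your validation metrics here
model.train()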
The foundational dataset class that handles individual image volumes:
dataset = CellMapDataset(
    raw_path="path/to/raw.zarr",
    target_path="path/to/gt.zarr",
    classes=["class1", "class2"],
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    is_train=True,
    pad=True,  # Pad arrays to requested size if needed
    device="cuda",
)
Key Features:
- Automatic 2D/3D handling and slicing
- Multiscale data support
- Memory-efficient random cropping
- Class balancing and weighting
- Spatial transformation pipeline
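Indexing the dataset returns one sample; a quick sanity check (the dict-of-tensors return structure, keyed by the array names declared above, is an assumption consistent with the batch examples elsewhere in this document):

sample = dataset[0]
print(sample["raw"].shape)           # should match input_arrays["raw"]["shape"]
print(sample["segmentation"].shape)  # target tensor for the requested classes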
Combines multiple datasets for training across different samples:
from cellmap_data import CellMapMultiDataset
multi_dataset = CellMapMultiDataset(
    classes=classes,
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    datasets=[dataset1, dataset2, dataset3],
)

# Weighted sampling across datasets
sampler = multi_dataset.get_weighted_sampler(batch_size=4)
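In most cases you can skip manual sampler construction and let the loader build it, as in the earlier examples:

from cellmap_data import CellMapDataLoader

# Convenience path: weighted sampling handled inside the loader
loader = CellMapDataLoader(multi_dataset, batch_size=4, weighted_sampler=True)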
High-performance data loader with optimization features:
loader = CellMapDataLoader(
    dataset,
    batch_size=16,
    num_workers=12,
    weighted_sampler=True,
    device="cuda",
    iterations_per_epoch=1000,  # For large datasets
)

# Optimized GPU memory transfer
loader.to("cuda", non_blocking=True)
Optimizations:
- CUDA streams for parallel GPU transfer
- Persistent workers for reduced overhead
- Automatic memory estimation and optimization
- Thread-safe multiprocessing
Manages train/validation splits with configuration:
datasplit = CellMapDataSplit(
    dataset_dict={
        "train": [
            {"raw": "path1/raw.zarr", "gt": "path1/gt.zarr"},
            {"raw": "path2/raw.zarr", "gt": "path2/gt.zarr"},
        ],
        "validate": [
            {"raw": "path3/raw.zarr", "gt": "path3/gt.zarr"},
        ],
    },
    classes=classes,
    input_arrays=input_arrays,
    target_arrays=target_arrays,
)
Comprehensive augmentation pipeline for robust training:
spatial_transforms = {
    "mirror": {
        "axes": {"x": 0.5, "y": 0.5, "z": 0.1}  # Probability per axis
    },
    "rotate": {
        "axes": {"z": [-45, 45], "y": [-15, 15]}  # Angle ranges in degrees
    },
    "transpose": {
        "axes": ["x", "y"]  # Axes to randomly reorder
    },
}
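To make the semantics concrete, here is roughly what one random draw of this configuration does, written out with plain torch ops (an illustration only, not the library's implementation):

import torch

x = torch.rand(64, 64, 64)                 # a (z, y, x) patch
if torch.rand(()) < 0.5:                   # "mirror" x with probability 0.5
    x = torch.flip(x, dims=[-1])
if torch.rand(()) < 0.1:                   # "mirror" z with probability 0.1
    x = torch.flip(x, dims=[0])
angle = torch.empty(()).uniform_(-45, 45)  # "rotate" about z draws an angle in
                                           # [-45, 45]; resampling is done internally
if torch.rand(()) < 0.5:                   # "transpose": randomly swap x and y
    x = x.transpose(-1, -2)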
Built-in preprocessing and augmentation transforms:
import torch
import torchvision.transforms.v2 as T

from cellmap_data.transforms import (
    Normalize, GaussianNoise, RandomContrast,
    RandomGamma, Binarize, NaNtoNum, GaussianBlur,
)

# Input preprocessing
raw_transforms = T.Compose([
    Normalize(scale=1/255),      # Normalize to [0,1]
    GaussianNoise(std=0.1),      # Add noise
    RandomContrast((0.8, 1.2)),  # Vary contrast
    NaNtoNum({"nan": 0}),        # Handle NaN values
])

# Target preprocessing
target_transforms = T.Compose([
    Binarize(threshold=0.5),   # Convert to binary
    T.ToDtype(torch.float32),  # Ensure float32
])
Support for mutually exclusive classes and true negative inference:
# Define class relationships
class_relation_dict = {
    "mitochondria": ["cytoplasm", "nucleus"],   # Mutually exclusive
    "endoplasmic_reticulum": ["mitochondria"],  # Cannot overlap
}

dataset = CellMapDataset(
    # ... other parameters ...
    classes=["mitochondria", "endoplasmic_reticulum", "nucleus", "cytoplasm"],
    class_relation_dict=class_relation_dict,
    # True negatives automatically inferred from relationships
)
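Conceptually, exclusivity turns one class's positive annotations into known negatives for the other. A toy illustration (not the library's code), with NaN marking unannotated voxels:

import torch

cytoplasm = torch.tensor([True, False, True])  # annotated cytoplasm positives
mito = torch.full((3,), float("nan"))          # mitochondria: unannotated
mito = torch.where(cytoplasm, torch.zeros(3), mito)
print(mito)  # tensor([0., nan, 0.]) -- inferred true negatives for mitochondria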
For datasets larger than available memory:
# Use subset sampling for large datasets
loader = CellMapDataLoader(
    large_dataset,
    batch_size=8,
    iterations_per_epoch=5000,  # Subsample each epoch
    weighted_sampler=True,
)

# Refresh sampler between epochs
for epoch in range(num_epochs):
    loader.refresh()  # New random subset
    for batch in loader:
        # Training code
        ...
Generate predictions and write to disk efficiently:
from cellmap_data import CellMapDatasetWriter
writer = CellMapDatasetWriter(
    raw_path="input.zarr",
    target_path="predictions.zarr",
    classes=["class1", "class2"],
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    target_bounds={"array": {"x": [0, 1000], "y": [0, 1000], "z": [0, 100]}},
)

# Write predictions tile by tile
for idx in range(len(writer)):
    inputs = writer[idx]
    with torch.no_grad():
        predictions = model(inputs)
    writer[idx] = {"segmentation": predictions}
- OME-NGFF/Zarr: Primary format with multiscale support and full read/write capabilities
- Local/S3/GCS: Various storage backends via TensorStore
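The same stores can be inspected directly with TensorStore if needed; a sketch for a local Zarr (the s0 scale-level path is an assumption about your store's layout):

import tensorstore as ts

store = ts.open({
    "driver": "zarr",
    "kvstore": {"driver": "file", "path": "/path/to/raw/data.zarr/s0"},
}).result()
print(store.shape, store.dtype)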
Automatic handling of multiscale datasets:
# Automatically selects an appropriate scale level
dataset = CellMapDataset(
    raw_path="data.zarr",  # Contains s0, s1, s2, ... scale levels
    target_path="labels.zarr",
    # ... other parameters ...
)

# Multiple scales can be requested side by side
input_arrays = {
    "raw_4nm": {
        "shape": (128, 128, 128),
        "scale": (4, 4, 4),
    },
    "raw_8nm": {
        "shape": (64, 64, 64),
        "scale": (8, 8, 8),
    },
}
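A dataset built with this two-scale spec should then yield both arrays per sample, co-centered at their respective resolutions (the return keys mirroring the spec is an assumption to check against the API docs):

dataset = CellMapDataset(
    raw_path="data.zarr",
    target_path="labels.zarr",
    classes=["mitochondria"],
    input_arrays=input_arrays,  # the "raw_4nm" / "raw_8nm" spec above
    target_arrays={"labels": {"shape": (64, 64, 64), "scale": (8, 8, 8)}},
)
sample = dataset[0]
# sample["raw_4nm"]: 128**3 voxels of fine detail at 4 nm
# sample["raw_8nm"]: 64**3 voxels of wider context at 8 nm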
- Efficient tensor operations with minimal copying
- Automatic GPU memory management
- Streaming data loading for large volumes
- Multi-threaded data loading
- CUDA streams for GPU optimization
- Process-safe dataset pickling
- Persistent ThreadPoolExecutor for reduced overhead
- Optimized coordinate transformations
- Minimal redundant computations
# Multi-class cell segmentation
classes = ["cell_boundary", "mitochondria", "nucleus", "er"]

spatial_transforms = {
    "mirror": {"axes": {"x": 0.5, "y": 0.5}},
    "rotate": {"axes": {"z": [-180, 180]}},
}

dataset = CellMapDataset(
    raw_path="em_data.zarr",
    target_path="segmentation_labels.zarr",
    classes=classes,
    input_arrays={"em": {"shape": (128, 128, 128), "scale": (4, 4, 4)}},
    target_arrays={"labels": {"shape": (128, 128, 128), "scale": (4, 4, 4)}},
    spatial_transforms=spatial_transforms,
    is_train=True,
)
# Training across multiple biological samples
datasplit = CellMapDataSplit(
    csv_path="multi_sample_split.csv",
    classes=organelle_classes,
    input_arrays=input_config,
    target_arrays=target_config,
    spatial_transforms=augmentation_config,
)

# Balanced sampling across datasets
train_loader = CellMapDataLoader(
    datasplit.train_datasets_combined,
    batch_size=16,
    weighted_sampler=True,
    num_workers=16,
)
# Generate predictions on new data
writer = CellMapDatasetWriter(
    raw_path="new_sample.zarr",
    target_path="predictions.zarr",
    classes=trained_classes,
    input_arrays=inference_config,
    target_arrays=output_config,
    target_bounds=volume_bounds,
)

# Process in tiles
for idx in writer.writer_indices:  # Non-overlapping tiles
    batch = writer[idx]
    with torch.no_grad():
        predictions = model(batch["input"])
    writer[idx] = {"segmentation": predictions}
- Choose patch sizes that fit comfortably in GPU memory
- Enable padding for datasets smaller than patch size
- Use weighted sampling for imbalanced datasets
- Configure appropriate number of workers (typically 2x CPU cores)
- Enable CUDA streams for multi-GPU setups
- Monitor memory usage with large datasets
- Use iterations_per_epoch for very large datasets
- Refresh samplers between epochs for dataset variety
- Start with small patch sizes and a single worker
- Use force_has_data=True for testing with empty datasets
- Check dataset.verify() before training (see the sketch below)
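Combining these tips, a minimal debugging setup might look like this (parameter names follow the examples above; verify() returning a truthy result is an assumption):

debug_dataset = CellMapDataset(
    raw_path="/path/to/raw/data.zarr",
    target_path="/path/to/labels/data.zarr",
    classes=["mitochondria"],
    input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8, 8, 8)}},
    target_arrays={"segmentation": {"shape": (32, 32, 32), "scale": (8, 8, 8)}},
    force_has_data=True,  # tolerate empty/sparse volumes while testing
)
assert debug_dataset.verify()  # sanity-check before training
debug_loader = CellMapDataLoader(debug_dataset, batch_size=1, num_workers=0)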
For complete API documentation, visit: https://janelia-cellmap.github.io/cellmap-data/
We welcome contributions! Please see our contributing guidelines for details on:
- Code style and standards
- Testing requirements
- Documentation expectations
- Pull request process
If you use CellMap-Data in your research, please cite:
@software{cellmap_data,
  title={CellMap-Data: PyTorch Data Loading for Biological Imaging},
  author={Rhoades, Jeff and the CellMap Team},
  url={https://github.com/janelia-cellmap/cellmap-data},
  year={2024}
}
This project is licensed under the BSD 3-Clause License - see the LICENSE file for details.
- Documentation
- Issue Tracker
- Discussions
- Contact: [email protected]