Skip to content

Commit bfe33cf

Browse files
authored
Merge branch 'dev' into fix-he-stain-order
2 parents c30d291 + b5bc69d commit bfe33cf

File tree

7 files changed

+221
-25
lines changed

7 files changed

+221
-25
lines changed

monai/inferers/inferer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,7 @@ def sample(
916916
verbose: bool = True,
917917
seg: torch.Tensor | None = None,
918918
cfg: float | None = None,
919+
cfg_fill_value: float = -1.0,
919920
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
920921
"""
921922
Args:
@@ -929,6 +930,7 @@ def sample(
929930
verbose: if true, prints the progression bar of the sampling process.
930931
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
931932
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
933+
cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
932934
"""
933935
if mode not in ["crossattn", "concat"]:
934936
raise NotImplementedError(f"{mode} condition is not supported")
@@ -961,7 +963,7 @@ def sample(
961963
model_input = torch.cat([image] * 2, dim=0)
962964
if conditioning is not None:
963965
uncondition = torch.ones_like(conditioning)
964-
uncondition.fill_(-1)
966+
uncondition.fill_(cfg_fill_value)
965967
conditioning_input = torch.cat([uncondition, conditioning], dim=0)
966968
else:
967969
conditioning_input = None
@@ -1261,6 +1263,7 @@ def sample( # type: ignore[override]
12611263
verbose: bool = True,
12621264
seg: torch.Tensor | None = None,
12631265
cfg: float | None = None,
1266+
cfg_fill_value: float = -1.0,
12641267
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
12651268
"""
12661269
Args:
@@ -1276,6 +1279,7 @@ def sample( # type: ignore[override]
12761279
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
12771280
is instance of SPADEAutoencoderKL, segmentation must be provided.
12781281
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1282+
cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
12791283
"""
12801284

12811285
if (
@@ -1300,6 +1304,7 @@ def sample( # type: ignore[override]
13001304
verbose=verbose,
13011305
seg=seg,
13021306
cfg=cfg,
1307+
cfg_fill_value=cfg_fill_value,
13031308
)
13041309

13051310
if save_intermediates:
@@ -1479,6 +1484,7 @@ def sample( # type: ignore[override]
14791484
verbose: bool = True,
14801485
seg: torch.Tensor | None = None,
14811486
cfg: float | None = None,
1487+
cfg_fill_value: float = -1.0,
14821488
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
14831489
"""
14841490
Args:
@@ -1493,7 +1499,8 @@ def sample( # type: ignore[override]
14931499
mode: Conditioning mode for the network.
14941500
verbose: if true, prints the progression bar of the sampling process.
14951501
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
1496-
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1502+
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1503+
cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
14971504
"""
14981505
if mode not in ["crossattn", "concat"]:
14991506
raise NotImplementedError(f"{mode} condition is not supported")
@@ -1521,7 +1528,7 @@ def sample( # type: ignore[override]
15211528
model_input = torch.cat([image] * 2, dim=0)
15221529
if conditioning is not None:
15231530
uncondition = torch.ones_like(conditioning)
1524-
uncondition.fill_(-1)
1531+
uncondition.fill_(cfg_fill_value)
15251532
conditioning_input = torch.cat([uncondition, conditioning], dim=0)
15261533
else:
15271534
conditioning_input = None
@@ -1839,6 +1846,7 @@ def sample( # type: ignore[override]
18391846
verbose: bool = True,
18401847
seg: torch.Tensor | None = None,
18411848
cfg: float | None = None,
1849+
cfg_fill_value: float = -1.0,
18421850
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
18431851
"""
18441852
Args:
@@ -1856,6 +1864,7 @@ def sample( # type: ignore[override]
18561864
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
18571865
is instance of SPADEAutoencoderKL, segmentation must be provided.
18581866
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1867+
cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
18591868
"""
18601869

18611870
if (
@@ -1884,6 +1893,7 @@ def sample( # type: ignore[override]
18841893
verbose=verbose,
18851894
seg=seg,
18861895
cfg=cfg,
1896+
cfg_fill_value=cfg_fill_value,
18871897
)
18881898

18891899
if save_intermediates:

monai/transforms/spatial/array.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
GridSamplePadMode,
6565
InterpolateMode,
6666
NumpyPadMode,
67+
SpaceKeys,
6768
convert_to_cupy,
6869
convert_to_dst_type,
6970
convert_to_numpy,
@@ -75,6 +76,7 @@
7576
issequenceiterable,
7677
optional_import,
7778
)
79+
from monai.utils.deprecate_utils import deprecated_arg_default
7880
from monai.utils.enums import GridPatchSort, PatchKeys, TraceKeys, TransformBackends
7981
from monai.utils.misc import ImageMetaKey as Key
8082
from monai.utils.module import look_up_option
@@ -556,11 +558,20 @@ class Orientation(InvertibleTransform, LazyTransform):
556558

557559
backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
558560

561+
@deprecated_arg_default(
562+
name="labels",
563+
old_default=(("L", "R"), ("P", "A"), ("I", "S")),
564+
new_default=None,
565+
msg_suffix=(
566+
"Default value changed to None meaning that the transform now uses the 'space' of a "
567+
"meta-tensor, if applicable, to determine appropriate axis labels."
568+
),
569+
)
559570
def __init__(
560571
self,
561572
axcodes: str | None = None,
562573
as_closest_canonical: bool = False,
563-
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
574+
labels: Sequence[tuple[str, str]] | None = None,
564575
lazy: bool = False,
565576
) -> None:
566577
"""
@@ -573,7 +584,14 @@ def __init__(
573584
as_closest_canonical: if True, load the image as closest to canonical axis format.
574585
labels: optional, None or sequence of (2,) sequences
575586
(2,) sequences are labels for (beginning, end) of output axis.
576-
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
587+
If ``None``, an appropriate value is chosen depending on the
588+
value of the ``"space"`` metadata item of a metatensor: if
589+
``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
590+
('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
591+
input is not a meta-tensor or has no ``"space"`` item, the
592+
value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
593+
``None``, the provided value is always used and the ``"space"``
594+
metadata item (if any) of the input is ignored.
577595
lazy: a flag to indicate whether this transform should execute lazily or not.
578596
Defaults to False
579597
@@ -619,9 +637,19 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
619637
raise ValueError(f"data_array must have at least one spatial dimension, got {spatial_shape}.")
620638
affine_: np.ndarray
621639
affine_np: np.ndarray
640+
labels = self.labels
622641
if isinstance(data_array, MetaTensor):
623642
affine_np, *_ = convert_data_type(data_array.peek_pending_affine(), np.ndarray)
624643
affine_ = to_affine_nd(sr, affine_np)
644+
645+
# Set up "labels" such that LPS tensors are handled correctly by default
646+
if (
647+
self.labels is None
648+
and "space" in data_array.meta
649+
and SpaceKeys(data_array.meta["space"]) == SpaceKeys.LPS
650+
):
651+
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS
652+
625653
else:
626654
warnings.warn("`data_array` is not of type `MetaTensor, assuming affine to be identity.")
627655
# default to identity
@@ -640,7 +668,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
640668
f"{self.__class__.__name__}: spatial shape = {spatial_shape}, channels = {data_array.shape[0]},"
641669
"please make sure the input is in the channel-first format."
642670
)
643-
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels)
671+
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=labels)
644672
if len(dst) < sr:
645673
raise ValueError(
646674
f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D"
@@ -653,8 +681,19 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
653681
transform = self.pop_transform(data)
654682
# Create inverse transform
655683
orig_affine = transform[TraceKeys.EXTRA_INFO]["original_affine"]
656-
orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
657-
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=self.labels)
684+
labels = self.labels
685+
686+
# Set up "labels" such that LPS tensors are handled correctly by default
687+
if (
688+
isinstance(data, MetaTensor)
689+
and self.labels is None
690+
and "space" in data.meta
691+
and SpaceKeys(data.meta["space"]) == SpaceKeys.LPS
692+
):
693+
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS
694+
695+
orig_axcodes = nib.orientations.aff2axcodes(orig_affine, labels=labels)
696+
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=labels)
658697
# Apply inverse
659698
with inverse_transform.trace_transform(False):
660699
data = inverse_transform(data)

monai/transforms/spatial/dictionary.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
ensure_tuple_rep,
7272
fall_back_tuple,
7373
)
74+
from monai.utils.deprecate_utils import deprecated_arg_default
7475
from monai.utils.enums import TraceKeys
7576
from monai.utils.module import optional_import
7677

@@ -545,12 +546,21 @@ class Orientationd(MapTransform, InvertibleTransform, LazyTransform):
545546

546547
backend = Orientation.backend
547548

549+
@deprecated_arg_default(
550+
name="labels",
551+
old_default=(("L", "R"), ("P", "A"), ("I", "S")),
552+
new_default=None,
553+
msg_suffix=(
554+
"Default value changed to None meaning that the transform now uses the 'space' of a "
555+
"meta-tensor, if applicable, to determine appropriate axis labels."
556+
),
557+
)
548558
def __init__(
549559
self,
550560
keys: KeysCollection,
551561
axcodes: str | None = None,
552562
as_closest_canonical: bool = False,
553-
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
563+
labels: Sequence[tuple[str, str]] | None = None,
554564
allow_missing_keys: bool = False,
555565
lazy: bool = False,
556566
) -> None:
@@ -564,7 +574,14 @@ def __init__(
564574
as_closest_canonical: if True, load the image as closest to canonical axis format.
565575
labels: optional, None or sequence of (2,) sequences
566576
(2,) sequences are labels for (beginning, end) of output axis.
567-
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
577+
If ``None``, an appropriate value is chosen depending on the
578+
value of the ``"space"`` metadata item of a metatensor: if
579+
``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
580+
('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
581+
input is not a meta-tensor or has no ``"space"`` item, the
582+
value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
583+
``None``, the provided value is always used and the ``"space"``
584+
metadata item (if any) of the input is ignored.
568585
allow_missing_keys: don't raise exception if key is missing.
569586
lazy: a flag to indicate whether this transform should execute lazily or not.
570587
Defaults to False

tests/inferers/test_diffusion_inferer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def test_sample_cfg(self, model_params, input_shape):
106106
save_intermediates=True,
107107
intermediate_steps=1,
108108
cfg=5,
109+
cfg_fill_value=-1,
109110
)
110111
self.assertEqual(sample.shape, noise.shape)
111112

tests/inferers/test_latent_diffusion_inferer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ def test_sample_shape_with_cfg(
456456
scheduler=scheduler,
457457
seg=input_seg,
458458
cfg=5,
459+
cfg_fill_value=-1,
459460
)
460461
else:
461462
sample = inferer.sample(

0 commit comments

Comments
 (0)