Skip to content

Commit d7bd325

Browse files
Add missing Maskformer dataclass decorator, add dataclass check in ModelOutput for subclasses (#25638)
* Add @DataClass to MaskFormerPixelDecoderOutput * Add dataclass check if subclass of ModelOutout * Use unittest assertRaises rather than pytest per contribution doc * Update src/transformers/utils/generic.py per suggested change Co-authored-by: amyeroberts <[email protected]> --------- Co-authored-by: amyeroberts <[email protected]>
1 parent 05de038 commit d7bd325

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

src/transformers/models/maskformer/modeling_maskformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class MaskFormerPixelLevelModuleOutput(ModelOutput):
118118
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
119119

120120

121+
@dataclass
121122
class MaskFormerPixelDecoderOutput(ModelOutput):
122123
"""
123124
MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state

src/transformers/utils/generic.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from collections import OrderedDict, UserDict
2121
from collections.abc import MutableMapping
2222
from contextlib import ExitStack, contextmanager
23-
from dataclasses import fields
23+
from dataclasses import fields, is_dataclass
2424
from enum import Enum
2525
from typing import Any, ContextManager, List, Tuple
2626

@@ -314,7 +314,26 @@ def __init_subclass__(cls) -> None:
314314
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
315315
)
316316

317+
def __init__(self, *args, **kwargs):
318+
super().__init__(*args, **kwargs)
319+
320+
# Subclasses of ModelOutput must use the @dataclass decorator
321+
# This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
322+
# issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
323+
# Just need to check that the current class is not ModelOutput
324+
is_modeloutput_subclass = self.__class__ != ModelOutput
325+
326+
if is_modeloutput_subclass and not is_dataclass(self):
327+
raise TypeError(
328+
f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
329+
" This is a subclass of ModelOutput and so must use the @dataclass decorator."
330+
)
331+
317332
def __post_init__(self):
333+
"""Check the ModelOutput dataclass.
334+
335+
Only occurs if @dataclass decorator has been used.
336+
"""
318337
class_fields = fields(self)
319338

320339
# Safety and consistency checks

tests/utils/test_model_output.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,23 @@ def test_torch_pytree(self):
143143

144144
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
145145
self.assertEqual(x, unflattened_x)
146+
147+
148+
class ModelOutputTestNoDataclass(ModelOutput):
149+
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
150+
151+
a: float
152+
b: Optional[float] = None
153+
c: Optional[float] = None
154+
155+
156+
class ModelOutputSubclassTester(unittest.TestCase):
157+
def test_direct_model_output(self):
158+
# Check that direct usage of ModelOutput instantiates without errors
159+
ModelOutput({"a": 1.1})
160+
161+
def test_subclass_no_dataclass(self):
162+
# Check that a subclass of ModelOutput without @dataclass is invalid
163+
# A valid subclass is inherently tested other unit tests above.
164+
with self.assertRaises(TypeError):
165+
ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)

0 commit comments

Comments
 (0)