diff --git a/python/mlc/_cython/base.py b/python/mlc/_cython/base.py index dee51191..362f4174 100644 --- a/python/mlc/_cython/base.py +++ b/python/mlc/_cython/base.py @@ -377,6 +377,10 @@ def fget(this: typing.Any, _name: str = name) -> typing.Any: def fset(this: typing.Any, value: typing.Any, _name: str = name) -> None: setter(this, value) # type: ignore[misc] + fget.__name__ = fset.__name__ = name + fget.__module__ = fset.__module__ = cls.__module__ + fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}" # type: ignore[attr-defined] + fget.__doc__ = fset.__doc__ = f"Property `{name}` of class `{cls.__qualname__}`" # type: ignore[attr-defined] prop = property( fget=fget if getter else None, fset=fset if (not frozen) and setter else None, diff --git a/python/mlc/core/object.py b/python/mlc/core/object.py index 821ea57b..d741f573 100644 --- a/python/mlc/core/object.py +++ b/python/mlc/core/object.py @@ -2,7 +2,7 @@ import typing -from mlc._cython import PyAny, c_class_core +from mlc._cython import PyAny, TypeInfo, c_class_core @c_class_core("object.Object") @@ -65,6 +65,25 @@ def __eq__(self, other: typing.Any) -> bool: def __ne__(self, other: typing.Any) -> bool: return not self == other + def _mlc_setattr(self, name: str, value: typing.Any) -> None: + type_info: TypeInfo = type(self)._mlc_type_info + for field in type_info.fields: + if field.name == name: + if field.setter is None: + raise AttributeError(f"Attribute `{name}` missing setter") + field.setter(self, value) + return + raise AttributeError(f"Attribute `{name}` not found in `{type(self)}`") + + def _mlc_getattr(self, name: str) -> typing.Any: + type_info: TypeInfo = type(self)._mlc_type_info + for field in type_info.fields: + if field.name == name: + if field.getter is None: + raise AttributeError(f"Attribute `{name}` missing getter") + return field.getter(self) + raise AttributeError(f"Attribute `{name}` not found in `{type(self)}`") + def swap(self, other: typing.Any) -> None: if type(self) == type(other): self._mlc_swap(other) diff --git a/python/mlc/dataclasses/c_class.py b/python/mlc/dataclasses/c_class.py index e13caac9..0bc70c1f 100644 --- a/python/mlc/dataclasses/c_class.py +++ b/python/mlc/dataclasses/c_class.py @@ -39,7 +39,7 @@ class type_cls(super_type_cls): # type: ignore[valid-type,misc] if type_info.type_cls is not None: raise ValueError(f"Type is already registered: {type_key}") - _, d_fields = inspect_dataclass_fields(type_key, type_cls, parent_type_info) + _, d_fields = inspect_dataclass_fields(type_key, type_cls, parent_type_info, frozen=False) type_info.type_cls = type_cls type_info.d_fields = tuple(d_fields) diff --git a/python/mlc/dataclasses/py_class.py b/python/mlc/dataclasses/py_class.py index 5bf99f28..d066ec35 100644 --- a/python/mlc/dataclasses/py_class.py +++ b/python/mlc/dataclasses/py_class.py @@ -49,6 +49,7 @@ def py_class( *, init: bool = True, repr: bool = True, + frozen: bool = False, structure: typing.Literal["bind", "nobind", "var"] | None = None, ) -> Callable[[type[ClsType]], type[ClsType]]: if isinstance(type_key, type): @@ -86,6 +87,7 @@ def decorator(super_type_cls: type[ClsType]) -> type[ClsType]: type_key, super_type_cls, parent_type_info, + frozen=frozen, ) num_bytes = _add_field_properties(fields) type_info.fields = tuple(fields) diff --git a/python/mlc/dataclasses/utils.py b/python/mlc/dataclasses/utils.py index 48d142f0..6ab0e5de 100644 --- a/python/mlc/dataclasses/utils.py +++ b/python/mlc/dataclasses/utils.py @@ -124,6 +124,7 @@ def inspect_dataclass_fields( # noqa: PLR0912 type_key: str, type_cls: type, parent_type_info: TypeInfo, + frozen: bool, ) -> tuple[list[TypeField], list[Field]]: def _get_num_bytes(field_ty: Any) -> int: if hasattr(field_ty, "_ctype"): @@ -136,6 +137,7 @@ def _get_num_bytes(field_ty: Any) -> int: for type_field in parent_type_info.fields: field_name = type_field.name field_ty = type_field.ty + field_frozen = type_field.frozen if type_hints.pop(field_name, None) is None: raise ValueError( f"Missing field `{type_key}::{field_name}`, " @@ -146,7 +148,7 @@ def _get_num_bytes(field_ty: Any) -> int: name=field_name, offset=-1, num_bytes=_get_num_bytes(field_ty), - frozen=False, + frozen=field_frozen, ty=field_ty, ) ) @@ -159,7 +161,7 @@ def _get_num_bytes(field_ty: Any) -> int: name=field_name, offset=-1, num_bytes=_get_num_bytes(field_ty), - frozen=False, + frozen=frozen, ty=field_ty, ) ) diff --git a/tests/python/test_dataclasses_py_class.py b/tests/python/test_dataclasses_py_class.py index 3284353a..8e51ded2 100644 --- a/tests/python/test_dataclasses_py_class.py +++ b/tests/python/test_dataclasses_py_class.py @@ -2,6 +2,7 @@ import mlc import mlc.dataclasses as mlcd +import pytest @mlcd.py_class("mlc.testing.py_class_base") @@ -57,6 +58,12 @@ def __post_init__(self) -> None: self.b = self.b.upper() +@mlcd.py_class("mlc.testing.py_class_frozen", frozen=True) +class Frozen(mlcd.PyClass): + a: int + b: str + + def test_base() -> None: base = Base(1, "a") base_str = "mlc.testing.py_class_base(base_a=1, base_b='a')" @@ -120,6 +127,30 @@ def test_post_init() -> None: assert repr(post_init) == "mlc.testing.py_class_post_init(a=1, b='A')" +def test_frozen_set_fail() -> None: + frozen = Frozen(1, "a") + with pytest.raises(AttributeError) as e: + frozen.a = 2 + # depends on Python version, there are a few possible error messages + assert str(e.value) in [ + "property 'a' of 'Frozen' object has no setter", + "can't set attribute", + ] + assert frozen.a == 1 + assert frozen.b == "a" + + +def test_frozen_force_set() -> None: + frozen = Frozen(1, "a") + frozen._mlc_setattr("a", 2) + assert frozen.a == 2 + assert frozen.b == "a" + + frozen._mlc_setattr("b", "b") + assert frozen.a == 2 + assert frozen.b == "b" + + def test_derived_derived() -> None: # __init__(base_a, derived_derived_a, base_b, derived_a, derived_b) obj = DerivedDerived(1, "a", [1, 2], 2, "b")