diff --git a/cpp/structure.cc b/cpp/structure.cc index 54082f6..602a368 100644 --- a/cpp/structure.cc +++ b/cpp/structure.cc @@ -538,9 +538,24 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { } else if (lhs_type_index == kMLCFunc || lhs_type_index == kMLCError) { throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", new_path); } else if (lhs_type_index == kMLCOpaque) { - std::ostringstream err; - err << "Cannot compare `mlc.Opaque` of type: " << lhs->DynCast()->opaque_type_name; - throw SEqualError(err.str().c_str(), new_path); + std::string func_name = "Opaque.eq_s."; + func_name += lhs->DynCast()->opaque_type_name; + FuncObj *func = Func::GetGlobal(func_name.c_str(), true); + if (func == nullptr) { + std::ostringstream err; + err << "Cannot compare `mlc.Opaque` of type: " << lhs->DynCast()->opaque_type_name << "; Use " + << "`mlc.Func.register(\"" << func_name << "\")(eq_s_func)` to register a comparison method"; + throw SEqualError(err.str().c_str(), new_path); + } + Any result = (*func)(lhs, rhs); + if (result.type_index != kMLCBool) { + std::ostringstream err; + err << "Comparison function `" << func_name << "` must return a boolean value, but got: " << result; + throw SEqualError(err.str().c_str(), new_path); + } + if (result.operator bool() == false) { + MLC_CORE_EQ_S_ERR(lhs, rhs, new_path); + } } else { bool visited = false; MLCTypeInfo *type_info = Lib::GetTypeInfo(lhs_type_index); @@ -802,9 +817,21 @@ inline uint64_t StructuralHashImpl(Object *obj) { } else if (type_index == kMLCFunc || type_index == kMLCError) { throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", ObjectPath::Root()); } else if (type_index == kMLCOpaque) { - std::ostringstream err; - err << "Cannot compare `mlc.Opaque` of type: " << obj->DynCast()->opaque_type_name; - throw SEqualError(err.str().c_str(), ObjectPath::Root()); + std::string func_name = "Opaque.hash_s."; + func_name += obj->DynCast()->opaque_type_name; + FuncObj *func = Func::GetGlobal(func_name.c_str(), true); + if (func == nullptr) { + MLC_THROW(ValueError) << "Cannot hash `mlc.Opaque` of type: " << obj->DynCast()->opaque_type_name + << "; Use `mlc.Func.register(\"" << func_name + << "\")(hash_s_func)` to register a hashing method"; + } + Any result = (*func)(obj); + if (result.type_index != kMLCInt) { + MLC_THROW(TypeError) << "Hashing function `" << func_name + << "` must return an integer value, but got: " << result; + } + int64_t hash_value = result.operator int64_t(); + EnqueuePOD(tasks, hash_value); } else { MLCTypeInfo *type_info = Lib::GetTypeInfo(type_index); tasks->emplace_back(Task{obj, type_info, false, bind_free_vars, type_info->type_key_hash}); diff --git a/include/mlc/core/func.h b/include/mlc/core/func.h index a474e1f..17a61d7 100644 --- a/include/mlc/core/func.h +++ b/include/mlc/core/func.h @@ -15,7 +15,7 @@ struct FuncObj : public MLCFunc { using SafeCall = int32_t(const FuncObj *, int32_t, const AnyView *, Any *); struct Allocator; - template MLC_INLINE Any operator()(Args &&...args) const { + template inline Any operator()(Args &&...args) const { constexpr size_t N = sizeof...(Args); AnyViewArray stack_args; Any ret; diff --git a/python/mlc/_cython/core.pyx b/python/mlc/_cython/core.pyx index ad41ed7..e610d35 100644 --- a/python/mlc/_cython/core.pyx +++ b/python/mlc/_cython/core.pyx @@ -687,6 +687,10 @@ cdef inline MLCAny _any_py2c(object x, list temporary_storage): y = (x)._mlc_any elif isinstance(x, Str): y = ((x._pyany))._mlc_any + elif isinstance(x, _OPAQUE_TYPES): + x = _pyany_from_opaque(x) + y = (x)._mlc_any + temporary_storage.append(x) elif isinstance(x, bool): y = _MLCAnyBool(x) elif isinstance(x, Integral): @@ -713,10 +717,6 @@ cdef inline MLCAny _any_py2c(object x, list temporary_storage): x = _pyany_from_dlpack(x) y = (x)._mlc_any temporary_storage.append(x) - elif isinstance(x, _OPAQUE_TYPES): - x = _pyany_from_opaque(x) - y = (x)._mlc_any - temporary_storage.append(x) else: raise TypeError(f"MLC does not recognize type: {type(x)}") return y diff --git a/tests/python/test_core_opaque.py b/tests/python/test_core_opaque.py index 63c3cfa..b89295d 100644 --- a/tests/python/test_core_opaque.py +++ b/tests/python/test_core_opaque.py @@ -1,20 +1,30 @@ +from typing import Any + import mlc import pytest -class MyType: +class MyTypeNotRegistered: def __init__(self, a: int) -> None: self.a = a -class MyTypeNotRegistered: +class MyType: def __init__(self, a: int) -> None: self.a = a + def __call__(self, x: int) -> int: + return x + self.a + mlc.Opaque.register(MyType) +@mlc.dataclasses.py_class(structure="bind") +class Wrapper(mlc.dataclasses.PyClass): + field: Any = mlc.dataclasses.field(structure="nobind") + + def test_opaque_init() -> None: a = MyType(a=10) opaque = mlc.Opaque(a) @@ -47,3 +57,40 @@ def test_opaque_ffi_error() -> None: str(e.value) == "MLC does not recognize type: " ) + + +def test_opaque_dataclass() -> None: + a = MyType(a=10) + wrapper = Wrapper(field=a) + assert isinstance(wrapper.field, MyType) + assert wrapper.field.a == 10 + + +@mlc.Func.register("Opaque.eq_s.test_core_opaque.MyType") +def _eq_s_MyType(a: MyType, b: MyType) -> bool: + return isinstance(a, MyType) and isinstance(b, MyType) and a.a == b.a + + +@mlc.Func.register("Opaque.hash_s.test_core_opaque.MyType") +def _hash_s_MyType(a: MyType) -> int: + assert isinstance(a, MyType) + return hash((MyType, a.a)) + + +def test_opaque_dataclass_eq_s() -> None: + a1 = Wrapper(field=MyType(a=10)) + a2 = Wrapper(field=MyType(a=10)) + a1.eq_s(a2, assert_mode=True) + + +def test_opaque_dataclass_eq_s_fail() -> None: + a1 = Wrapper(field=MyType(a=10)) + a2 = Wrapper(field=MyType(a=20)) + with pytest.raises(ValueError) as exc_info: + a1.eq_s(a2, assert_mode=True) + assert str(exc_info.value).startswith("Structural equality check failed at {root}.field") + + +def test_opaque_dataclass_hash_s() -> None: + a1 = Wrapper(field=MyType(a=10)) + assert isinstance(a1.hash_s(), int)