diff --git a/cpp/registry.h b/cpp/registry.h index 3add388a..03384f1d 100644 --- a/cpp/registry.h +++ b/cpp/registry.h @@ -659,7 +659,7 @@ inline TypeTable *TypeTable::New() { self->SetFunc("mlc.core.TensorFromBytes", Func(::mlc::registry::TensorFromBytes).get()); self->SetFunc("mlc.core.TensorToBase64", Func(::mlc::registry::TensorToBase64).get()); self->SetFunc("mlc.core.TensorFromBase64", Func(::mlc::registry::TensorFromBase64).get()); - self->SetFunc("mlc.core.TensorToDLPack", Func([](TensorObj *tensor) -> void * { return tensor->DLPack(); }).get()); + self->SetFunc("mlc.core.TensorToDLPack", Func([](Tensor tensor) -> void * { return tensor->DLPack(); }).get()); self->SetFunc("mlc.printer.DocToPythonScript", Func(::mlc::registry::DocToPythonScript).get()); self->SetFunc("mlc.printer.ToPython", Func(::mlc::printer::ToPython).get()); diff --git a/python/mlc/core/tensor.py b/python/mlc/core/tensor.py index 8045bc8f..cb0c1b67 100644 --- a/python/mlc/core/tensor.py +++ b/python/mlc/core/tensor.py @@ -67,7 +67,7 @@ def base64(self) -> str: def from_base64(base64: str) -> Tensor: return TensorFromBase64(base64) - def __dlpack__(self) -> Any: + def __dlpack__(self, stream: Any = None) -> Any: return tensor_to_dlpack(self) def __dlpack_device__(self) -> tuple[int, int]: