Skip to content

Commit 07ade88

Browse files
authored
fix(core): Fix DLPack Interface (#47)
1 parent a70ee4a commit 07ade88

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

cpp/registry.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ inline TypeTable *TypeTable::New() {
659659
self->SetFunc("mlc.core.TensorFromBytes", Func(::mlc::registry::TensorFromBytes).get());
660660
self->SetFunc("mlc.core.TensorToBase64", Func(::mlc::registry::TensorToBase64).get());
661661
self->SetFunc("mlc.core.TensorFromBase64", Func(::mlc::registry::TensorFromBase64).get());
662-
self->SetFunc("mlc.core.TensorToDLPack", Func([](TensorObj *tensor) -> void * { return tensor->DLPack(); }).get());
662+
self->SetFunc("mlc.core.TensorToDLPack", Func([](Tensor tensor) -> void * { return tensor->DLPack(); }).get());
663663
self->SetFunc("mlc.printer.DocToPythonScript", Func(::mlc::registry::DocToPythonScript).get());
664664
self->SetFunc("mlc.printer.ToPython", Func(::mlc::printer::ToPython).get());
665665

python/mlc/core/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def base64(self) -> str:
6767
def from_base64(base64: str) -> Tensor:
6868
return TensorFromBase64(base64)
6969

70-
def __dlpack__(self) -> Any:
70+
def __dlpack__(self, stream: Any = None) -> Any:
7171
return tensor_to_dlpack(self)
7272

7373
def __dlpack_device__(self) -> tuple[int, int]:

0 commit comments

Comments
 (0)