diff --git a/jvm/core/src/main/java/org/apache/tvm/Function.java b/jvm/core/src/main/java/org/apache/tvm/Function.java index df535a87aa85..594b35b0af68 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Function.java +++ b/jvm/core/src/main/java/org/apache/tvm/Function.java @@ -222,6 +222,16 @@ public Function pushArg(byte[] arg) { return this; } + /** + * Push argument to the function. + * @param arg Device. + * @return this + */ + public Function pushArg(Device arg) { + Base._LIB.tvmFuncPushArgDevice(arg); + return this; + } + /** * Invoke function with arguments. * @param args Can be Integer, Long, Float, Double, String, NDArray. @@ -255,6 +265,8 @@ private static void pushArgToStack(Object arg) { Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id); } else if (arg instanceof Function) { Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id); + } else if (arg instanceof Device) { + Base._LIB.tvmFuncPushArgDevice((Device) arg); } else if (arg instanceof TVMValue) { TVMValue tvmArg = (TVMValue) arg; switch (tvmArg.typeCode) { diff --git a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java index 62b8c901bd71..aede9be334c8 100644 --- a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java +++ b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java @@ -37,6 +37,8 @@ class LibInfo { native void tvmFuncPushArgHandle(long arg, int argType); + native void tvmFuncPushArgDevice(Device device); + native int tvmFuncListGlobalNames(List funcNames); native int tvmFuncFree(long handle); diff --git a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java index 737fdef24ae8..0a0bc7efc46d 100644 --- a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java +++ b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java @@ -41,7 +41,7 @@ public class GraphModule { private Function fdebugGetOutput; private Function floadParams; - GraphModule(Module module, Device dev) { + public GraphModule(Module module, Device dev) { this.module = module; this.device = dev; fsetInput = module.getFunction("set_input"); diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index d60a1a4230b7..3e44f757392d 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -214,4 +214,25 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) { return NULL; } +// Helper function to pack two int32_t values into an int64_t +inline int64_t deviceToInt64(const int32_t device_type, const int32_t device_id) { + int64_t result; + int32_t* parts = reinterpret_cast(&result); + + // Lambda function to check endianness + const auto isLittleEndian = []() -> bool { + uint32_t i = 1; + return *reinterpret_cast(&i) == 1; + }; + + if (isLittleEndian()) { + parts[0] = device_type; + parts[1] = device_id; + } else { + parts[1] = device_type; + parts[0] = device_id; + } + return result; +} + #endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_ diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index 09522381f181..c039508b4b7f 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -112,6 +112,21 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv* e->tvmFuncArgTypes.push_back(static_cast(argType)); } +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDevice(JNIEnv* env, jobject obj, + jobject arg) { + jclass deviceClass = env->FindClass("org/apache/tvm/Device"); + jfieldID deviceTypeField = env->GetFieldID(deviceClass, "deviceType", "I"); + jfieldID deviceIdField = env->GetFieldID(deviceClass, "deviceId", "I"); + jint deviceType = env->GetIntField(arg, deviceTypeField); + jint deviceId = env->GetIntField(arg, deviceIdField); + + TVMValue value; + value.v_int64 = deviceToInt64(deviceType, deviceId); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); + e->tvmFuncArgValues.push_back(value); + e->tvmFuncArgTypes.push_back(kDLDevice); +} + JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj, jbyteArray arg) { jbyteArray garg = reinterpret_cast(env->NewGlobalRef(arg));