Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions jvm/core/src/main/java/org/apache/tvm/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,16 @@ public Function pushArg(byte[] arg) {
return this;
}

/**
* Push argument to the function.
* @param arg Device.
* @return this
*/
public Function pushArg(Device arg) {
Base._LIB.tvmFuncPushArgDevice(arg);
return this;
}

/**
* Invoke function with arguments.
* @param args Can be Integer, Long, Float, Double, String, NDArray.
Expand Down Expand Up @@ -255,6 +265,8 @@ private static void pushArgToStack(Object arg) {
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id);
} else if (arg instanceof Function) {
Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id);
} else if (arg instanceof Device) {
Base._LIB.tvmFuncPushArgDevice((Device) arg);
} else if (arg instanceof TVMValue) {
TVMValue tvmArg = (TVMValue) arg;
switch (tvmArg.typeCode) {
Expand Down
2 changes: 2 additions & 0 deletions jvm/core/src/main/java/org/apache/tvm/LibInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class LibInfo {

native void tvmFuncPushArgHandle(long arg, int argType);

native void tvmFuncPushArgDevice(Device device);

native int tvmFuncListGlobalNames(List<String> funcNames);

native int tvmFuncFree(long handle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class GraphModule {
private Function fdebugGetOutput;
private Function floadParams;

GraphModule(Module module, Device dev) {
public GraphModule(Module module, Device dev) {
this.module = module;
this.device = dev;
fsetInput = module.getFunction("set_input");
Expand Down
21 changes: 21 additions & 0 deletions jvm/native/src/main/native/jni_helper_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,25 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) {
return NULL;
}

// Helper function to pack two int32_t values into an int64_t
inline int64_t deviceToInt64(const int32_t device_type, const int32_t device_id) {
int64_t result;
int32_t* parts = reinterpret_cast<int32_t*>(&result);

// Lambda function to check endianness
const auto isLittleEndian = []() -> bool {
uint32_t i = 1;
return *reinterpret_cast<char*>(&i) == 1;
};

if (isLittleEndian()) {
parts[0] = device_type;
parts[1] = device_id;
} else {
parts[1] = device_type;
parts[0] = device_id;
}
return result;
}

#endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
15 changes: 15 additions & 0 deletions jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,21 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv*
e->tvmFuncArgTypes.push_back(static_cast<int>(argType));
}

JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDevice(JNIEnv* env, jobject obj,
jobject arg) {
jclass deviceClass = env->FindClass("org/apache/tvm/Device");
jfieldID deviceTypeField = env->GetFieldID(deviceClass, "deviceType", "I");
jfieldID deviceIdField = env->GetFieldID(deviceClass, "deviceId", "I");
jint deviceType = env->GetIntField(arg, deviceTypeField);
jint deviceId = env->GetIntField(arg, deviceIdField);

TVMValue value;
value.v_int64 = deviceToInt64(deviceType, deviceId);
TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kDLDevice);
}

JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj,
jbyteArray arg) {
jbyteArray garg = reinterpret_cast<jbyteArray>(env->NewGlobalRef(arg));
Expand Down
Loading