diff --git a/obfuscator/src/main/resources/sources/micro_vm.cpp b/obfuscator/src/main/resources/sources/micro_vm.cpp index 6debefa..f74473f 100644 --- a/obfuscator/src/main/resources/sources/micro_vm.cpp +++ b/obfuscator/src/main/resources/sources/micro_vm.cpp @@ -10,6 +10,7 @@ #include #include #include +#include // NOLINTBEGIN - obfuscated control flow by design namespace native_jvm::vm { @@ -107,6 +108,182 @@ static thread_local std::unordered_map inst static thread_local std::unordered_map static_field_cache{}; static thread_local std::unordered_map instance_field_cache{}; +enum class PrimitiveArrayKind { + BOOLEAN, + BYTE, + CHAR, + SHORT, + INT, + LONG, + FLOAT, + DOUBLE +}; + +struct PrimitiveArrayCacheEntry { + jarray array = nullptr; + void* elements = nullptr; + jsize length = 0; + bool modified = false; + PrimitiveArrayKind kind = PrimitiveArrayKind::INT; +}; + +struct PrimitiveArrayCacheHash { + size_t operator()(jarray array) const noexcept { + return std::hash{}(array); + } +}; + +class PrimitiveArrayCache { +public: + explicit PrimitiveArrayCache(JNIEnv* env) : env(env) {} + PrimitiveArrayCache(const PrimitiveArrayCache&) = delete; + PrimitiveArrayCache& operator=(const PrimitiveArrayCache&) = delete; + + ~PrimitiveArrayCache() { + release_all(); + } + + template + ElementType* get(ArrayType array, bool write, PrimitiveArrayCacheEntry** out_entry) { + if (array == nullptr) { + return nullptr; + } + auto key = reinterpret_cast(array); + auto [it, inserted] = entries.try_emplace(key); + PrimitiveArrayCacheEntry& entry = it->second; + if (inserted) { + entry.array = key; + entry.length = env->GetArrayLength(array); + entry.elements = env->GetPrimitiveArrayCritical(array, nullptr); + entry.modified = write; + entry.kind = Kind; + if (entry.elements == nullptr) { + entries.erase(it); + return nullptr; + } + } else { + if (entry.kind != Kind) { + return nullptr; + } + if (write) { + entry.modified = true; + } + } + if (out_entry != nullptr) { + *out_entry = &entry; + } + return static_cast(entry.elements); + } + + void release_all() { + for (auto& kv : entries) { + auto& entry = kv.second; + if (entry.elements != nullptr) { + env->ReleasePrimitiveArrayCritical(entry.array, entry.elements, entry.modified ? 0 : JNI_ABORT); + entry.elements = nullptr; + } + } + entries.clear(); + } + +private: + JNIEnv* env; + std::unordered_map entries{}; +}; + +static void throw_null_array(JNIEnv* env) { + jclass npe = env->FindClass("java/lang/NullPointerException"); + if (npe != nullptr) { + env->ThrowNew(npe, "null"); + env->DeleteLocalRef(npe); + } +} + +static void throw_array_index_oob(JNIEnv* env, jsize index, jsize length) { + jclass oob = env->FindClass("java/lang/ArrayIndexOutOfBoundsException"); + if (oob != nullptr) { + char buffer[96]; + std::snprintf(buffer, sizeof(buffer), "Index %d out of bounds for length %d", index, length); + env->ThrowNew(oob, buffer); + env->DeleteLocalRef(oob); + } +} + +struct ObjectArrayCacheKey { + jobjectArray array = nullptr; + jsize index = 0; + + bool operator==(const ObjectArrayCacheKey& other) const noexcept { + return array == other.array && index == other.index; + } +}; + +struct ObjectArrayCacheKeyHash { + size_t operator()(const ObjectArrayCacheKey& key) const noexcept { + size_t base = std::hash{}(key.array); + return base ^ (static_cast(key.index) << 1); + } +}; + +class ObjectArrayCache { +public: + explicit ObjectArrayCache(JNIEnv* env) : env(env) {} + ObjectArrayCache(const ObjectArrayCache&) = delete; + ObjectArrayCache& operator=(const ObjectArrayCache&) = delete; + + ~ObjectArrayCache() { + clear(); + } + + jobject get(jobjectArray array, jsize index) { + if (array == nullptr) { + throw_null_array(env); + return nullptr; + } + ObjectArrayCacheKey key{array, index}; + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + jsize length = env->GetArrayLength(array); + if (index < 0 || index >= length) { + throw_array_index_oob(env, index, length); + return nullptr; + } + jobject local = env->GetObjectArrayElement(array, index); + if (local == nullptr) { + return nullptr; + } + jobject global = env->NewGlobalRef(local); + env->DeleteLocalRef(local); + if (global == nullptr) { + return nullptr; + } + cache.emplace(key, global); + return global; + } + + void invalidate(jobjectArray array, jsize index) { + ObjectArrayCacheKey key{array, index}; + auto it = cache.find(key); + if (it != cache.end()) { + env->DeleteGlobalRef(it->second); + cache.erase(it); + } + } + +private: + void clear() { + for (auto& entry : cache) { + env->DeleteGlobalRef(entry.second); + } + cache.clear(); + } + + JNIEnv* env; + std::unordered_map cache{}; +}; + static jclass get_cached_class(JNIEnv* env, const char* name) { auto it = class_cache.find(name); if (it != class_cache.end()) { @@ -528,6 +705,8 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, uint64_t state = KEY ^ seed; OpCode op = OP_NOP; uint64_t mask = 0; + PrimitiveArrayCache array_cache(env); + ObjectArrayCache object_cache(env); goto dispatch; // start of the threaded interpreter @@ -1558,11 +1737,12 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, do_aaload: if (sp >= 2) { - int64_t index = stack[--sp]; + jsize index = static_cast(stack[--sp]); jobjectArray arr = reinterpret_cast(stack[--sp]); - jobject val = env->GetObjectArrayElement(arr, static_cast(index)); - stack[sp++] = reinterpret_cast(val); - env->DeleteLocalRef(val); + jobject val = object_cache.get(arr, index); + if (val != nullptr) { + stack[sp++] = reinterpret_cast(val); + } } goto dispatch; @@ -1572,6 +1752,7 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, jsize index = static_cast(stack[--sp]); jobjectArray arr = reinterpret_cast(stack[--sp]); env->SetObjectArrayElement(arr, index, value); + object_cache.invalidate(arr, index); } goto dispatch; @@ -1579,9 +1760,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, if (sp >= 2) { jsize index = static_cast(stack[--sp]); jintArray arr = reinterpret_cast(stack[--sp]); - jint val; - env->GetIntArrayRegion(arr, index, 1, &val); - stack[sp++] = val; + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jint* elems = array_cache.get(arr, false, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + stack[sp++] = elems[index]; + } + } + } } goto dispatch; @@ -1589,9 +1780,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, if (sp >= 2) { jsize index = static_cast(stack[--sp]); jlongArray arr = reinterpret_cast(stack[--sp]); - jlong val; - env->GetLongArrayRegion(arr, index, 1, &val); - stack[sp++] = val; + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jlong* elems = array_cache.get(arr, false, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + stack[sp++] = elems[index]; + } + } + } } goto dispatch; @@ -1599,12 +1800,21 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, if (sp >= 2) { jsize index = static_cast(stack[--sp]); jfloatArray arr = reinterpret_cast(stack[--sp]); - jfloat val; - env->GetFloatArrayRegion(arr, index, 1, &val); - // Convert float to int bits for storage - int32_t bits; - std::memcpy(&bits, &val, sizeof(float)); - stack[sp++] = static_cast(bits); + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jfloat* elems = array_cache.get(arr, false, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + int32_t bits; + std::memcpy(&bits, &elems[index], sizeof(float)); + stack[sp++] = static_cast(bits); + } + } + } } goto dispatch; @@ -1612,12 +1822,21 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, if (sp >= 2) { jsize index = static_cast(stack[--sp]); jdoubleArray arr = reinterpret_cast(stack[--sp]); - jdouble val; - env->GetDoubleArrayRegion(arr, index, 1, &val); - // Convert double to long bits for storage - int64_t bits; - std::memcpy(&bits, &val, sizeof(double)); - stack[sp++] = bits; + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jdouble* elems = array_cache.get(arr, false, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + int64_t bits; + std::memcpy(&bits, &elems[index], sizeof(double)); + stack[sp++] = bits; + } + } + } } goto dispatch; @@ -1625,9 +1844,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, if (sp >= 2) { jsize index = static_cast(stack[--sp]); jbyteArray arr = reinterpret_cast(stack[--sp]); - jbyte val; - env->GetByteArrayRegion(arr, index, 1, &val); - stack[sp++] = val; + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jbyte* elems = array_cache.get(arr, false, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + stack[sp++] = elems[index]; + } + } + } } goto dispatch; @@ -1635,9 +1864,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, if (sp >= 2) { jsize index = static_cast(stack[--sp]); jcharArray arr = reinterpret_cast(stack[--sp]); - jchar val; - env->GetCharArrayRegion(arr, index, 1, &val); - stack[sp++] = val; + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jchar* elems = array_cache.get(arr, false, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + stack[sp++] = elems[index]; + } + } + } } goto dispatch; @@ -1645,9 +1884,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, if (sp >= 2) { jsize index = static_cast(stack[--sp]); jshortArray arr = reinterpret_cast(stack[--sp]); - jshort val; - env->GetShortArrayRegion(arr, index, 1, &val); - stack[sp++] = val; + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jshort* elems = array_cache.get(arr, false, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + stack[sp++] = elems[index]; + } + } + } } goto dispatch; @@ -1656,7 +1905,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, jint value = static_cast(stack[--sp]); jsize index = static_cast(stack[--sp]); jintArray arr = reinterpret_cast(stack[--sp]); - env->SetIntArrayRegion(arr, index, 1, &value); + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jint* elems = array_cache.get(arr, true, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + elems[index] = value; + } + } + } } goto dispatch; @@ -1665,7 +1926,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, jlong value = static_cast(stack[--sp]); jsize index = static_cast(stack[--sp]); jlongArray arr = reinterpret_cast(stack[--sp]); - env->SetLongArrayRegion(arr, index, 1, &value); + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jlong* elems = array_cache.get(arr, true, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + elems[index] = value; + } + } + } } goto dispatch; @@ -1677,7 +1950,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, std::memcpy(&value, &bits, sizeof(float)); jsize index = static_cast(stack[--sp]); jfloatArray arr = reinterpret_cast(stack[--sp]); - env->SetFloatArrayRegion(arr, index, 1, &value); + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jfloat* elems = array_cache.get(arr, true, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + elems[index] = value; + } + } + } } goto dispatch; @@ -1689,7 +1974,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, std::memcpy(&value, &bits, sizeof(double)); jsize index = static_cast(stack[--sp]); jdoubleArray arr = reinterpret_cast(stack[--sp]); - env->SetDoubleArrayRegion(arr, index, 1, &value); + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jdouble* elems = array_cache.get(arr, true, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + elems[index] = value; + } + } + } } goto dispatch; @@ -1698,7 +1995,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, jbyte value = static_cast(stack[--sp]); jsize index = static_cast(stack[--sp]); jbyteArray arr = reinterpret_cast(stack[--sp]); - env->SetByteArrayRegion(arr, index, 1, &value); + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jbyte* elems = array_cache.get(arr, true, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + elems[index] = value; + } + } + } } goto dispatch; @@ -1707,7 +2016,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, jchar value = static_cast(stack[--sp]); jsize index = static_cast(stack[--sp]); jcharArray arr = reinterpret_cast(stack[--sp]); - env->SetCharArrayRegion(arr, index, 1, &value); + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jchar* elems = array_cache.get(arr, true, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + elems[index] = value; + } + } + } } goto dispatch; @@ -1716,7 +2037,19 @@ int64_t execute(JNIEnv* env, const Instruction* code, size_t length, jshort value = static_cast(stack[--sp]); jsize index = static_cast(stack[--sp]); jshortArray arr = reinterpret_cast(stack[--sp]); - env->SetShortArrayRegion(arr, index, 1, &value); + if (arr == nullptr) { + throw_null_array(env); + } else { + PrimitiveArrayCacheEntry* entry = nullptr; + jshort* elems = array_cache.get(arr, true, &entry); + if (elems != nullptr && entry != nullptr) { + if (index < 0 || index >= entry->length) { + throw_array_index_oob(env, index, entry->length); + } else { + elems[index] = value; + } + } + } } goto dispatch;