diff --git a/extension/android/BUCK b/extension/android/BUCK
index 040c9258d42..ac85160a38c 100644
--- a/extension/android/BUCK
+++ b/extension/android/BUCK
@@ -25,6 +25,8 @@ fb_android_library(
     srcs = [
         "src/main/java/org/pytorch/executorch/LlamaCallback.java",
         "src/main/java/org/pytorch/executorch/LlamaModule.java",
+        "src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java",
+        "src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java",
     ],
     autoglob = False,
     language = "JAVA",
diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp
index 551307b495a..f3c62e1d70f 100644
--- a/extension/android/jni/jni_layer.cpp
+++ b/extension/android/jni/jni_layer.cpp
@@ -408,14 +408,14 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
 } // namespace executorch::extension
 
 #ifdef EXECUTORCH_BUILD_LLAMA_JNI
-extern void register_natives_for_llama();
+extern void register_natives_for_llm();
 #else
-// No op if we don't build llama
-void register_natives_for_llama() {}
+// No op if we don't build LLM
+void register_natives_for_llm() {}
 #endif
 JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
   return facebook::jni::initialize(vm, [] {
     executorch::extension::ExecuTorchJni::registerNatives();
-    register_natives_for_llama();
+    register_natives_for_llm();
   });
 }
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index e71323da201..d6ade74ee1f 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -75,14 +75,14 @@ std::string token_buffer;
 
 namespace executorch_jni {
 
-class ExecuTorchLlamaCallbackJni
-    : public facebook::jni::JavaClass<ExecuTorchLlamaCallbackJni> {
+class ExecuTorchLlmCallbackJni
+    : public facebook::jni::JavaClass<ExecuTorchLlmCallbackJni> {
  public:
   constexpr static const char* kJavaDescriptor =
-      "Lorg/pytorch/executorch/LlamaCallback;";
+      "Lorg/pytorch/executorch/extension/llm/LlmCallback;";
 
   void onResult(std::string result) const {
-    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
+    static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
     static const auto method =
         cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
@@ -99,7 +99,7 @@ class ExecuTorchLlamaCallbackJni
   }
 
   void onStats(const llm::Stats& result) const {
-    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
+    static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
     static const auto method = cls->getMethod<void(jfloat)>("onStats");
     double eval_time =
         (double)(result.inference_end_ms - result.prompt_eval_end_ms);
@@ -111,8 +111,7 @@ class ExecuTorchLlamaCallbackJni
   }
 };
 
-class ExecuTorchLlamaJni
-    : public facebook::jni::HybridClass<ExecuTorchLlamaJni> {
+class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
  private:
   friend HybridBase;
   int model_type_category_;
@@ -121,7 +120,7 @@ class ExecuTorchLlamaJni
 
  public:
   constexpr static auto kJavaDescriptor =
-      "Lorg/pytorch/executorch/LlamaModule;";
+      "Lorg/pytorch/executorch/extension/llm/LlmModule;";
   constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
   constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
 
@@ -142,7 +141,7 @@ class ExecuTorchLlamaJni
         data_path);
   }
 
-  ExecuTorchLlamaJni(
+  ExecuTorchLlmJni(
       jint model_type_category,
       facebook::jni::alias_ref<jstring> model_path,
       facebook::jni::alias_ref<jstring> tokenizer_path,
@@ -197,7 +196,7 @@ class ExecuTorchLlamaJni
       jint channels,
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
-      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
+      facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       auto image_size = image->size();
@@ -296,7 +295,7 @@ class ExecuTorchLlamaJni
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
       jlong start_pos,
-      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
+      facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
       return static_cast<jint>(Error::NotSupported);
@@ -329,22 +328,22 @@ class ExecuTorchLlamaJni
 
   static void registerNatives() {
     registerHybrid({
-        makeNativeMethod("initHybrid", ExecuTorchLlamaJni::initHybrid),
-        makeNativeMethod("generate", ExecuTorchLlamaJni::generate),
-        makeNativeMethod("stop", ExecuTorchLlamaJni::stop),
-        makeNativeMethod("load", ExecuTorchLlamaJni::load),
+        makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid),
+        makeNativeMethod("generate", ExecuTorchLlmJni::generate),
+        makeNativeMethod("stop", ExecuTorchLlmJni::stop),
+        makeNativeMethod("load", ExecuTorchLlmJni::load),
         makeNativeMethod(
-            "prefillImagesNative", ExecuTorchLlamaJni::prefill_images),
+            "prefillImagesNative", ExecuTorchLlmJni::prefill_images),
         makeNativeMethod(
-            "prefillPromptNative", ExecuTorchLlamaJni::prefill_prompt),
+            "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
         makeNativeMethod(
-            "generateFromPos", ExecuTorchLlamaJni::generate_from_pos),
+            "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
     });
   }
 };
 
 } // namespace executorch_jni
 
-void register_natives_for_llama() {
-  executorch_jni::ExecuTorchLlamaJni::registerNatives();
+void register_natives_for_llm() {
+  executorch_jni::ExecuTorchLlmJni::registerNatives();
 }
diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java
index b30fa2515a9..33421f26f0f 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java
@@ -9,15 +9,14 @@
 package org.pytorch.executorch;
 
 import com.facebook.jni.annotations.DoNotStrip;
-import org.pytorch.executorch.annotations.Experimental;
 
 /**
  * Callback interface for Llama model. Users can implement this interface to receive the generated
  * tokens and statistics.
  *
- * <p>Warning: These APIs are experimental and subject to change without notice
+ * <p>Note: deprecated! Please use {@link org.pytorch.executorch.extension.llm.LlmCallback} instead.
  */
-@Experimental
+@Deprecated
 public interface LlamaCallback {
   /**
    * Called when a new result is available from JNI. Users will keep getting onResult() invocations
diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
index d64c7d0bf89..6a201fb56ea 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -8,42 +8,28 @@
 
 package org.pytorch.executorch;
 
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
-import com.facebook.soloader.nativeloader.NativeLoader;
-import com.facebook.soloader.nativeloader.SystemDelegate;
-import org.pytorch.executorch.annotations.Experimental;
+import org.pytorch.executorch.extension.llm.LlmCallback;
+import org.pytorch.executorch.extension.llm.LlmModule;
 
 /**
  * LlamaModule is a wrapper around the Executorch Llama model. It provides a simple interface to
  * generate text from the model.
  *
- * <p>Warning: These APIs are experimental and subject to change without notice
+ * <p>Note: deprecated! Please use {@link org.pytorch.executorch.extension.llm.LlmModule} instead.
 */
-@Experimental
+@Deprecated
 public class LlamaModule {
 
   public static final int MODEL_TYPE_TEXT = 1;
   public static final int MODEL_TYPE_TEXT_VISION = 2;
 
-  static {
-    if (!NativeLoader.isInitialized()) {
-      NativeLoader.init(new SystemDelegate());
-    }
-    NativeLoader.loadLibrary("executorch");
-  }
-
-  private final HybridData mHybridData;
+  private LlmModule mModule;
   private static final int DEFAULT_SEQ_LEN = 128;
   private static final boolean DEFAULT_ECHO = true;
 
-  @DoNotStrip
-  private static native HybridData initHybrid(
-      int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath);
-
   /** Constructs a LLAMA Module for a model with given model path, tokenizer, temperature. */
   public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
-    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
+    mModule = new LlmModule(modulePath, tokenizerPath, temperature);
   }
 
   /**
@@ -51,16 +37,16 @@ public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
    * path.
    */
   public LlamaModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
-    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
+    mModule = new LlmModule(modulePath, tokenizerPath, temperature, dataPath);
   }
 
   /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
   public LlamaModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
-    mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null);
+    mModule = new LlmModule(modelType, modulePath, tokenizerPath, temperature);
   }
 
   public void resetNative() {
-    mHybridData.resetNative();
+    mModule.resetNative();
   }
 
   /**
@@ -70,7 +56,7 @@ public void resetNative() {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, LlamaCallback llamaCallback) {
-    return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
+    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
   }
 
   /**
@@ -119,8 +105,7 @@ public int generate(String prompt, int seqLen, LlamaCallback llamaCallback, bool
    * @param llamaCallback callback object to receive results.
    * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    */
-  @DoNotStrip
-  public native int generate(
+  public int generate(
       int[] image,
       int width,
       int height,
@@ -128,7 +113,27 @@ public native int generate(
       String prompt,
       int seqLen,
       LlamaCallback llamaCallback,
-      boolean echo);
+      boolean echo) {
+    return mModule.generate(
+        image,
+        width,
+        height,
+        channels,
+        prompt,
+        seqLen,
+        new LlmCallback() {
+          @Override
+          public void onResult(String result) {
+            llamaCallback.onResult(result);
+          }
+
+          @Override
+          public void onStats(float tps) {
+            llamaCallback.onStats(tps);
+          }
+        },
+        echo);
+  }
 
   /**
    * Prefill an LLaVA Module with the given images input.
@@ -142,17 +147,9 @@ public native int generate(
    * @throws RuntimeException if the prefill failed
    */
   public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
-    long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
-    if (nativeResult[0] != 0) {
-      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
-    }
-    return nativeResult[1];
+    return mModule.prefillImages(image, width, height, channels, startPos);
   }
 
-  // returns a tuple of (status, updated startPos)
-  private native long[] prefillImagesNative(
-      int[] image, int width, int height, int channels, long startPos);
-
   /**
    * Prefill an LLaVA Module with the given text input.
    *
@@ -165,16 +162,9 @@ private native long[] prefillImagesNative(
    * @throws RuntimeException if the prefill failed
    */
   public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
-    long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
-    if (nativeResult[0] != 0) {
-      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
-    }
-    return nativeResult[1];
+    return mModule.prefillPrompt(prompt, startPos, bos, eos);
   }
 
-  // returns a tuple of (status, updated startPos)
-  private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
-
   /**
    * Generate tokens from the given prompt, starting from the given position.
    *
@@ -185,14 +175,33 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @param echo indicate whether to echo the input prompt or not.
    * @return The error code.
    */
-  public native int generateFromPos(
-      String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo);
+  public int generateFromPos(
+      String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo) {
+    return mModule.generateFromPos(
+        prompt,
+        seqLen,
+        startPos,
+        new LlmCallback() {
+          @Override
+          public void onResult(String result) {
+            callback.onResult(result);
+          }
+
+          @Override
+          public void onStats(float tps) {
+            callback.onStats(tps);
+          }
+        },
+        echo);
+  }
 
   /** Stop current generate() before it finishes. */
-  @DoNotStrip
-  public native void stop();
+  public void stop() {
+    mModule.stop();
+  }
 
   /** Force loading the module. Otherwise the model is loaded during first generate(). */
-  @DoNotStrip
-  public native int load();
+  public int load() {
+    return mModule.load();
+  }
 }
diff --git a/extension/android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java
new file mode 100644
index 00000000000..c05b30b0625
--- /dev/null
+++ b/extension/android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.executorch.extension.llm;
+
+import com.facebook.jni.annotations.DoNotStrip;
+import org.pytorch.executorch.annotations.Experimental;
+
+/**
+ * Callback interface for Llama model. Users can implement this interface to receive the generated
+ * tokens and statistics.
+ *
+ * <p>Warning: These APIs are experimental and subject to change without notice
+ */
+@Experimental
+public interface LlmCallback {
+  /**
+   * Called when a new result is available from JNI. Users will keep getting onResult() invocations
+   * until generate() finishes.
+   *
+   * @param result Last generated token
+   */
+  @DoNotStrip
+  public void onResult(String result);
+
+  /**
+   * Called when the statistics for the generate() is available.
+   *
+   * @param tps Tokens/second for generated tokens.
+   */
+  @DoNotStrip
+  public void onStats(float tps);
+}
diff --git a/extension/android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
new file mode 100644
index 00000000000..8262d7cfdad
--- /dev/null
+++ b/extension/android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.executorch.extension.llm;
+
+import com.facebook.jni.HybridData;
+import com.facebook.jni.annotations.DoNotStrip;
+import com.facebook.soloader.nativeloader.NativeLoader;
+import com.facebook.soloader.nativeloader.SystemDelegate;
+import org.pytorch.executorch.annotations.Experimental;
+
+/**
+ * LlmModule is a wrapper around the Executorch LLM. It provides a simple interface to
+ * generate text from the model.
+ *
+ * <p>Warning: These APIs are experimental and subject to change without notice
+ */
+@Experimental
+public class LlmModule {
+
+  public static final int MODEL_TYPE_TEXT = 1;
+  public static final int MODEL_TYPE_TEXT_VISION = 2;
+
+  static {
+    if (!NativeLoader.isInitialized()) {
+      NativeLoader.init(new SystemDelegate());
+    }
+    NativeLoader.loadLibrary("executorch");
+  }
+
+  private final HybridData mHybridData;
+  private static final int DEFAULT_SEQ_LEN = 128;
+  private static final boolean DEFAULT_ECHO = true;
+
+  @DoNotStrip
+  private static native HybridData initHybrid(
+      int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath);
+
+  /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */
+  public LlmModule(String modulePath, String tokenizerPath, float temperature) {
+    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
+  }
+
+  /**
+   * Constructs a LLM Module for a model with given model path, tokenizer, temperature and data
+   * path.
+   */
+  public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
+    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
+  }
+
+  /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
+  public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
+    mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null);
+  }
+
+  public void resetNative() {
+    mHybridData.resetNative();
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param llmCallback callback object to receive results.
+   */
+  public int generate(String prompt, LlmCallback llmCallback) {
+    return generate(prompt, DEFAULT_SEQ_LEN, llmCallback, DEFAULT_ECHO);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param seqLen sequence length
+   * @param llmCallback callback object to receive results.
+   */
+  public int generate(String prompt, int seqLen, LlmCallback llmCallback) {
+    return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, DEFAULT_ECHO);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param llmCallback callback object to receive results
+   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+   */
+  public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
+    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llmCallback, echo);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param seqLen sequence length
+   * @param llmCallback callback object to receive results
+   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+   */
+  public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) {
+    return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param image Input image as a byte array
+   * @param width Input image width
+   * @param height Input image height
+   * @param channels Input image number of channels
+   * @param prompt Input prompt
+   * @param seqLen sequence length
+   * @param llmCallback callback object to receive results.
+   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+   */
+  @DoNotStrip
+  public native int generate(
+      int[] image,
+      int width,
+      int height,
+      int channels,
+      String prompt,
+      int seqLen,
+      LlmCallback llmCallback,
+      boolean echo);
+
+  /**
+   * Prefill an LLaVA Module with the given images input.
+   *
+   * @param image Input image as a byte array
+   * @param width Input image width
+   * @param height Input image height
+   * @param channels Input image number of channels
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @return The updated starting position in KV cache of the input in the LLM.
+   * @throws RuntimeException if the prefill failed
+   */
+  public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
+    long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
+    if (nativeResult[0] != 0) {
+      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
+    }
+    return nativeResult[1];
+  }
+
+  // returns a tuple of (status, updated startPos)
+  private native long[] prefillImagesNative(
+      int[] image, int width, int height, int channels, long startPos);
+
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param startPos The starting position in KV cache of the input in the LLM. It's passed as
+   *     reference and will be updated inside this function.
+   * @param bos The number of BOS (begin of sequence) token.
+   * @param eos The number of EOS (end of sequence) token.
+   * @return The updated starting position in KV cache of the input in the LLM.
+   * @throws RuntimeException if the prefill failed
+   */
+  public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
+    long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
+    if (nativeResult[0] != 0) {
+      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
+    }
+    return nativeResult[1];
+  }
+
+  // returns a tuple of (status, updated startPos)
+  private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @param callback callback object to receive results.
+   * @param echo indicate whether to echo the input prompt or not.
+   * @return The error code.
+   */
+  public native int generateFromPos(
+      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);
+
+  /** Stop current generate() before it finishes. */
+  @DoNotStrip
+  public native void stop();
+
+  /** Force loading the module. Otherwise the model is loaded during first generate(). */
+  @DoNotStrip
+  public native int load();
+}
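Illustrative usage, not part of the patch above: a minimal sketch of how app code might migrate from the deprecated org.pytorch.executorch.LlamaModule wrapper to the new org.pytorch.executorch.extension.llm.LlmModule API introduced in this change. The model/tokenizer paths and generation parameters below are placeholders, not values taken from this diff.

import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

public class LlmModuleUsageSketch {
  public static void main(String[] args) {
    // Before: LlamaModule module = new LlamaModule(modelPath, tokenizerPath, 0.8f);
    // After: same constructor shape, under the new extension.llm package.
    LlmModule module =
        new LlmModule("/data/local/tmp/llama.pte", "/data/local/tmp/tokenizer.bin", 0.8f);

    module.load(); // optional: force-load now instead of on the first generate()

    // Tokens are streamed through the new LlmCallback interface.
    int err =
        module.generate(
            "Tell me a story",
            /* seqLen */ 128,
            new LlmCallback() {
              @Override
              public void onResult(String token) {
                System.out.print(token);
              }

              @Override
              public void onStats(float tps) {
                System.out.println("\ntokens/s: " + tps);
              }
            },
            /* echo */ false);
    if (err != 0) {
      System.err.println("generate() returned error code " + err);
    }

    module.resetNative(); // release the native module when done
  }
}

Existing callers of the deprecated LlamaModule keep working without changes, since it now simply forwards to LlmModule and adapts LlamaCallback to LlmCallback, as the diff shows.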