diff --git a/langchain/embeddings/llamacpp.py b/langchain/embeddings/llamacpp.py
index 44c887a8ddcc6..52500e9373847 100644
--- a/langchain/embeddings/llamacpp.py
+++ b/langchain/embeddings/llamacpp.py
@@ -53,6 +53,9 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
     """Number of tokens to process in parallel.
     Should be a number between 1 and n_ctx."""
 
+    n_gpu_layers: Optional[int] = None
+    """Number of layers to store in VRAM."""
+
     class Config:
         """Configuration for this pydantic object."""
 
@@ -71,6 +74,7 @@ def validate_environment(cls, values: Dict) -> Dict:
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
         n_batch = values["n_batch"]
+        n_gpu_layers = values["n_gpu_layers"]
 
         try:
             from llama_cpp import Llama
@@ -86,6 +90,7 @@ def validate_environment(cls, values: Dict) -> Dict:
                 use_mlock=use_mlock,
                 n_threads=n_threads,
                 n_batch=n_batch,
+                n_gpu_layers=n_gpu_layers,
                 embedding=True,
             )
         except ImportError:
diff --git a/langchain/llms/llamacpp.py b/langchain/llms/llamacpp.py
index 6a10af9dccf5c..43234597326d5 100644
--- a/langchain/llms/llamacpp.py
+++ b/langchain/llms/llamacpp.py
@@ -100,6 +100,9 @@ class LlamaCpp(LLM):
     streaming: bool = True
     """Whether to stream the results, token by token."""
 
+    n_gpu_layers: Optional[int] = None
+    """Number of layers to store in VRAM."""
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
@@ -117,6 +120,7 @@ def validate_environment(cls, values: Dict) -> Dict:
         n_batch = values["n_batch"]
         use_mmap = values["use_mmap"]
         last_n_tokens_size = values["last_n_tokens_size"]
+        n_gpu_layers = values["n_gpu_layers"]
 
         try:
             from llama_cpp import Llama
@@ -136,6 +140,7 @@ def validate_environment(cls, values: Dict) -> Dict:
                 n_batch=n_batch,
                 use_mmap=use_mmap,
                 last_n_tokens_size=last_n_tokens_size,
+                n_gpu_layers=n_gpu_layers,
             )
         except ImportError:
             raise ModuleNotFoundError(
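
Usage sketch for the new parameter (not part of the diff itself): it assumes llama-cpp-python has been built with GPU offload support and that a local GGML model file exists at the path shown; the model path and layer count below are illustrative placeholders.

from langchain.embeddings import LlamaCppEmbeddings
from langchain.llms import LlamaCpp

# Offload 32 transformer layers to VRAM; remaining layers stay on the CPU.
# "./models/ggml-model-q4_0.bin" is a placeholder path for a local model file.
llm = LlamaCpp(model_path="./models/ggml-model-q4_0.bin", n_gpu_layers=32)

embeddings = LlamaCppEmbeddings(
    model_path="./models/ggml-model-q4_0.bin",
    n_gpu_layers=32,
)

Leaving n_gpu_layers unset (None) preserves the previous behavior, since the argument is simply passed through to the underlying Llama client.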