Merged
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -3,7 +3,7 @@ resolver = "2"
members = [
"llama-cpp-sys-2",
"llama-cpp-2",
"simple",
"simple", "embeddings",
]

[workspace.dependencies]
15 changes: 15 additions & 0 deletions embeddings/Cargo.toml
@@ -0,0 +1,15 @@
[package]
name = "embeddings"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
llama-cpp-2 = { path = "../llama-cpp-2", version = "0.1.34" }
hf-hub = { workspace = true }
clap = { workspace = true, features = ["derive"] }
anyhow = { workspace = true }

[lints]
workspace = true
215 changes: 215 additions & 0 deletions embeddings/src/main.rs
@@ -0,0 +1,215 @@
//! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2.
#![allow(
clippy::cast_possible_wrap,
clippy::cast_possible_truncation,
clippy::cast_precision_loss,
clippy::cast_sign_loss
)]

use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;

use anyhow::{bail, Context, Result};
use clap::Parser;
use hf_hub::api::sync::ApiBuilder;

use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::ggml_time_us;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::AddBos;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::params::LlamaModelParams;

#[derive(clap::Parser, Debug, Clone)]
struct Args {
/// The model to use: either a local path or a model to download from HuggingFace
#[command(subcommand)]
model: Model,
/// The prompt
#[clap(default_value = "Hello my name is")]
prompt: String,
/// Whether to normalise the produced embeddings
#[clap(short)]
normalise: bool,
/// Disable offloading layers to the gpu
#[cfg(feature = "cublas")]
#[clap(long)]
disable_gpu: bool,
}

#[derive(clap::Subcommand, Debug, Clone)]
enum Model {
/// Use an already downloaded model
Local {
/// The path to the model. e.g. `/home/marcus/.cache/huggingface/hub/models--TheBloke--Llama-2-7B-Chat-GGUF/blobs/08a5566d61d7cb6b420c3e4387a39e0078e1f2fe5f055f3a03887385304d4bfa`
path: PathBuf,
},
/// Download a model from huggingface (or use a cached version)
#[clap(name = "hf-model")]
HuggingFace {
/// the repo containing the model. e.g. `BAAI/bge-small-en-v1.5`
repo: String,
/// the model name. e.g. `BAAI-bge-small-v1.5.Q4_K_M.gguf`
model: String,
},
}

impl Model {
/// Convert the model to a path - may download from huggingface
fn get_or_load(self) -> Result<PathBuf> {
match self {
Model::Local { path } => Ok(path),
Model::HuggingFace { model, repo } => ApiBuilder::new()
.with_progress(true)
.build()
.with_context(|| "unable to create huggingface api")?
.model(repo)
.get(&model)
.with_context(|| "unable to download model"),
}
}
}

fn main() -> Result<()> {
let Args {
model,
prompt,
normalise,
#[cfg(feature = "cublas")]
disable_gpu,
} = Args::parse();

// init LLM
let backend = LlamaBackend::init()?;

// offload all layers to the gpu
let model_params = {
#[cfg(feature = "cublas")]
if !disable_gpu {
LlamaModelParams::default().with_n_gpu_layers(1000)
} else {
LlamaModelParams::default()
}
#[cfg(not(feature = "cublas"))]
LlamaModelParams::default()
};

let model_path = model
.get_or_load()
.with_context(|| "failed to get model from args")?;

let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
.with_context(|| "unable to load model")?;

// initialize the context
let ctx_params = LlamaContextParams::default()
.with_n_threads_batch(std::thread::available_parallelism()?.get() as u32)
.with_embeddings(true);

let mut ctx = model
.new_context(&backend, ctx_params)
.with_context(|| "unable to create the llama_context")?;

// Split the prompt into lines to demonstrate the batching functionality
let prompt_lines = prompt.lines();

// tokenize the prompt
let tokens_lines_list = prompt_lines.map(|line| model.str_to_token(line, AddBos::Always))
.collect::<Result<Vec<_>, _>>()
.with_context(|| format!("failed to tokenize {prompt}"))?;

let n_ctx = ctx.n_ctx() as usize;
let n_ctx_train = model.n_ctx_train();

eprintln!("n_ctx = {n_ctx}, n_ctx_train = {n_ctx_train}");

if tokens_lines_list.iter().any(|tok| n_ctx < tok.len()) {
bail!("One of the provided prompts exceeds the size of the context window");
}

// print the prompt token-by-token
eprintln!();

for (i, token_line) in tokens_lines_list.iter().enumerate() {
eprintln!("Prompt {i}");
for token in token_line {
eprintln!(" {} --> {}", token, model.token_to_str(*token)?);
}
eprintln!()
}

std::io::stderr().flush()?;

// create a llama_batch with the size of the context
// we use this object to submit token data for decoding
let mut batch = LlamaBatch::new(n_ctx, 1);

let mut max_seq_id_batch = 0;
let mut output = Vec::with_capacity(tokens_lines_list.len());

let t_main_start = ggml_time_us();

for tokens in &tokens_lines_list {
// Flush the batch if the next prompt would exceed our batch size
if (batch.n_tokens() as usize + tokens.len()) > n_ctx {
batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;
max_seq_id_batch = 0;
}

batch.add_sequence(&tokens, max_seq_id_batch, false)?;
max_seq_id_batch += 1;
}
// Handle final batch
batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;

let t_main_end = ggml_time_us();

for (i, embeddings) in output.iter().enumerate() {
eprintln!("Embeddings {i}: {embeddings:?}");
eprintln!();
}

let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
let total_tokens: usize = tokens_lines_list.iter().map(|v| v.len()).sum();
eprintln!(
"Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n",
total_tokens,
duration.as_secs_f32(),
total_tokens as f32 / duration.as_secs_f32()
);

println!("{}", ctx.timings());

Ok(())
}

fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, output: &mut Vec<Vec<f32>>, normalise: bool) -> Result<()> {
ctx.clear_kv_cache();
ctx.decode(batch).with_context(|| "llama_decode() failed")?;

for i in 0..s_batch {
let embedding = ctx.embeddings_seq_ith(i).with_context(|| "Failed to get embeddings")?;
let output_embeddings = if normalise {
normalize(embedding)
} else {
embedding.to_vec()
};

output.push(output_embeddings);
}

batch.clear();

Ok(())
}

fn normalize(input: &[f32]) -> Vec<f32> {
let magnitude = input.iter().fold(0.0, |acc, &val| val.mul_add(val, acc)).sqrt();

input.iter().map(|&val| val / magnitude).collect()
}
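
The `normalize` helper above scales each embedding to unit length (its L2 norm), so dot products between normalised outputs are cosine similarities. As a rough illustration of how the `output` vectors filled by `batch_decode` might be compared downstream, here is a small sketch; `cosine_similarity` is not part of this PR and assumes nothing beyond plain `f32` slices:

/// Cosine similarity between two embeddings. For vectors already scaled to
/// unit length by `normalize`, this reduces to a plain dot product.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "embeddings must have the same dimension");
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}

// For example, after `batch_decode` has filled `output`:
// let score = cosine_similarity(&output[0], &output[1]);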
66 changes: 63 additions & 3 deletions llama-cpp-2/src/context.rs
@@ -2,15 +2,15 @@

use std::fmt::{Debug, Formatter};
use std::num::NonZeroI32;
use std::ptr::NonNull;
use std::slice;

use crate::llama_batch::LlamaBatch;
use crate::model::LlamaModel;
use crate::timing::LlamaTimings;
use crate::token::data::LlamaTokenData;
use crate::token::LlamaToken;
use crate::DecodeError;
use std::ptr::NonNull;
use std::slice;
use crate::{DecodeError, EmbeddingsError};

pub mod kv_cache;
pub mod params;
@@ -24,6 +24,7 @@ pub struct LlamaContext<'a> {
/// a reference to the contexts model.
pub model: &'a LlamaModel,
initialized_logits: Vec<i32>,
embeddings_enabled: bool,
}

impl Debug for LlamaContext<'_> {
@@ -38,11 +39,13 @@ impl<'model> LlamaContext<'model> {
pub(crate) fn new(
llama_model: &'model LlamaModel,
llama_context: NonNull<llama_cpp_sys_2::llama_context>,
embeddings_enabled: bool,
) -> Self {
Self {
context: llama_context,
model: llama_model,
initialized_logits: Vec::new(),
embeddings_enabled,
}
}

@@ -80,6 +83,63 @@ impl<'model> LlamaContext<'model> {
}
}

/// Get the embeddings for the `i`th sequence in the current context.
///
/// # Returns
///
/// A slice containing the embeddings for the last decoded batch.
/// The size corresponds to the `n_embd` parameter of the context's model.
///
/// # Errors
///
/// - When the current context was constructed without enabling embeddings.
/// - If the current model has a pooling type of [`llama_cpp_sys_2::LLAMA_POOLING_TYPE_NONE`].
/// - If the given sequence index exceeds the maximum sequence id.
pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
if !self.embeddings_enabled {
return Err(EmbeddingsError::NotEnabled);
}

unsafe {
let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i);

// Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
if embedding.is_null() {
Err(EmbeddingsError::NonePoolType)
} else {
Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
}
}
}

/// Get the embeddings for the `i`th token in the current context.
///
/// # Returns
///
/// A slice containing the embeddings for the given token in the last decoded batch.
/// The size corresponds to the `n_embd` parameter of the context's model.
///
/// # Errors
///
/// - When the current context was constructed without enabling embeddings.
/// - When logits were not enabled for the given token when it was added to the batch.
/// - If the given token index exceeds the maximum token index.
pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
if !self.embeddings_enabled {
return Err(EmbeddingsError::NotEnabled);
}

unsafe {
let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i);
// Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
if embedding.is_null() {
Err(EmbeddingsError::LogitsNotEnabled)
} else {
Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
}
}
}

/// Get the logits for the ith token in the context.
///
/// # Panics
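
Taken together, the context changes record an `embeddings_enabled` flag at construction time and gate both new accessors on it, returning an `EmbeddingsError` instead of building a slice from a null pointer. Below is a minimal sketch of how a caller might exercise the new API, assuming a `LlamaBackend` and `LlamaModel` are already loaded and eliding error-handling details; the batch size of 512 is arbitrary:

// Sketch only: `backend` and `model` are assumed to already exist.
let ctx_params = LlamaContextParams::default().with_embeddings(true);
let mut ctx = model.new_context(&backend, ctx_params)?;

let tokens = model.str_to_token("Hello my name is", AddBos::Always)?;
let mut batch = LlamaBatch::new(512, 1);
batch.add_sequence(&tokens, 0, false)?;
ctx.decode(&mut batch)?;

// Pooled embedding for sequence 0. This returns an error if the context was
// created without `with_embeddings(true)` or the model's pooling type is
// LLAMA_POOLING_TYPE_NONE.
let embedding: &[f32] = ctx.embeddings_seq_ith(0)?;
println!("embedding dimension: {}", embedding.len());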