Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions llama-cpp-2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,28 @@ pub enum StringToTokenError {
CIntConversionError(#[from] std::num::TryFromIntError),
}

/// Failed to apply model chat template.
#[derive(Debug, thiserror::Error)]
pub enum NewLlamaChatMessageError {
/// the string contained a null byte and thus could not be converted to a c string.
#[error("{0}")]
NulError(#[from] NulError),
}

/// Failed to apply model chat template.
#[derive(Debug, thiserror::Error)]
pub enum ApplyChatTemplateError {
/// the buffer was too small.
#[error("The buffer was too small. Please contact a maintainer and we will update it.")]
BuffSizeError,
/// the string contained a null byte and thus could not be converted to a c string.
#[error("{0}")]
NulError(#[from] NulError),
/// the string could not be converted to utf8.
#[error("{0}")]
FromUtf8Error(#[from] FromUtf8Error),
}

/// Get the time in microseconds according to ggml
///
/// ```
Expand Down
82 changes: 76 additions & 6 deletions llama-cpp-2/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::LlamaTokenType;
use crate::{
ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError, StringToTokenError,
TokenToStringError,
ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError,
NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
};

pub mod params;
Expand All @@ -25,6 +25,23 @@ pub struct LlamaModel {
pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}

/// A Safe wrapper around `llama_chat_message`
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
role: CString,
content: CString,
}

impl LlamaChatMessage {
/// Create a new `LlamaChatMessage`
pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
Ok(Self {
role: CString::new(role)?,
content: CString::new(content)?,
})
}
}

/// How to determine if we should prepend a bos token to tokens
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
Expand Down Expand Up @@ -312,17 +329,16 @@ impl LlamaModel {
/// Get chat template from model.
///
/// # Errors
///
///
/// * If the model has no chat template
/// * If the chat template is not a valid [`CString`].
#[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {

// longest known template is about 1200 bytes from llama.cpp
let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
let chat_ptr = chat_temp.into_raw();
let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");

let chat_template: String = unsafe {
let ret = llama_cpp_sys_2::llama_model_meta_val_str(
self.model.as_ptr(),
Expand All @@ -337,7 +353,7 @@ impl LlamaModel {
debug_assert_eq!(usize::try_from(ret).unwrap(), template.len(), "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case", template.len());
template
};

Ok(chat_template)
}

Expand Down Expand Up @@ -388,6 +404,60 @@ impl LlamaModel {

Ok(LlamaContext::new(self, context, params.embeddings()))
}

/// Apply the models chat template to some messages.
/// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
///
/// `tmpl` of None means to use the default template provided by llama.cpp for the model
///
/// # Errors
/// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
#[tracing::instrument(skip_all)]
pub fn apply_chat_template(
&self,
tmpl: Option<String>,
chat: Vec<LlamaChatMessage>,
add_ass: bool,
) -> Result<String, ApplyChatTemplateError> {
// Buffer is twice the length of messages per their recommendation
let message_length = chat.iter().fold(0, |acc, c| {
acc + c.role.to_bytes().len() + c.content.to_bytes().len()
});
let mut buff: Vec<i8> = vec![0_i8; message_length * 2];

// Build our llama_cpp_sys_2 chat messages
let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
.iter()
.map(|c| llama_cpp_sys_2::llama_chat_message {
role: c.role.as_ptr(),
content: c.content.as_ptr(),
})
.collect();
// Set the tmpl pointer
let tmpl = tmpl.map(CString::new);
let tmpl_ptr = match tmpl {
Some(str) => str?.as_ptr(),
None => std::ptr::null(),
};
let formatted_chat = unsafe {
let res = llama_cpp_sys_2::llama_chat_apply_template(
self.model.as_ptr(),
tmpl_ptr,
chat.as_ptr(),
chat.len(),
add_ass,
buff.as_mut_ptr().cast::<std::os::raw::c_char>(),
buff.len() as i32,
);
// A buffer twice the size should be sufficient for all models, if this is not the case for a new model, we can increase it
// The error message informs the user to contact a maintainer
if res > buff.len() as i32 {
return Err(ApplyChatTemplateError::BuffSizeError);
}
String::from_utf8(buff.iter().filter(|c| **c > 0).map(|&c| c as u8).collect())
}?;
Ok(formatted_chat)
}
}

impl Drop for LlamaModel {
Expand Down