diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index 4b4236f73..c4f96c37f 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -214,7 +214,7 @@ public sealed record MetadataOverride
///
/// Get the key being overriden by this override
///
- public string Key { get; init; }
+ public string Key { get; }
internal LLamaModelKvOverrideType Type { get; }
diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs
index 3f038874f..dc7414490 100644
--- a/LLama/Common/ChatHistory.cs
+++ b/LLama/Common/ChatHistory.cs
@@ -1,5 +1,4 @@
using System.Collections.Generic;
-using System.IO;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -37,6 +36,7 @@ public enum AuthorRole
///
public class ChatHistory
{
+ private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };
///
/// Chat message representation
@@ -96,12 +96,7 @@ public void AddMessage(AuthorRole authorRole, string content)
///
public string ToJson()
{
- return JsonSerializer.Serialize(
- this,
- new JsonSerializerOptions()
- {
- WriteIndented = true
- });
+ return JsonSerializer.Serialize(this, _jsonOptions);
}
///
diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs
index 6d272f23f..8c14a1961 100644
--- a/LLama/Common/FixedSizeQueue.cs
+++ b/LLama/Common/FixedSizeQueue.cs
@@ -2,7 +2,6 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
-using LLama.Extensions;
namespace LLama.Common
{
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index c1f395505..0e6020ad4 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -18,11 +18,13 @@ public record InferenceParams
/// number of tokens to keep from initial prompt
///
public int TokensKeep { get; set; } = 0;
+
///
/// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response
/// until it complete.
///
public int MaxTokens { get; set; } = -1;
+
///
/// logit bias for specific tokens
///
diff --git a/LLama/Extensions/DictionaryExtensions.cs b/LLama/Extensions/DictionaryExtensions.cs
index b3643fae1..6599f6316 100644
--- a/LLama/Extensions/DictionaryExtensions.cs
+++ b/LLama/Extensions/DictionaryExtensions.cs
@@ -15,6 +15,7 @@ public static TValue GetValueOrDefault(this IReadOnlyDictionary(IReadOnlyDictionary dictionary, TKey key, TValue defaultValue)
{
+ // ReSharper disable once CanSimplifyDictionaryTryGetValueWithGetValueOrDefault (this is a shim for that method!)
return dictionary.TryGetValue(key, out var value) ? value : defaultValue;
}
}
diff --git a/LLama/Grammars/Grammar.cs b/LLama/Grammars/Grammar.cs
index 5135e341e..abb65aa30 100644
--- a/LLama/Grammars/Grammar.cs
+++ b/LLama/Grammars/Grammar.cs
@@ -15,7 +15,7 @@ public sealed class Grammar
///
/// Index of the initial rule to start from
///
- public ulong StartRuleIndex { get; set; }
+ public ulong StartRuleIndex { get; }
///
/// The rules which make up this grammar
@@ -121,6 +121,12 @@ private void PrintRule(StringBuilder output, GrammarRule rule)
case LLamaGrammarElementType.CHAR_ALT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
break;
+
+ case LLamaGrammarElementType.END:
+ case LLamaGrammarElementType.ALT:
+ case LLamaGrammarElementType.RULE_REF:
+ case LLamaGrammarElementType.CHAR:
+ case LLamaGrammarElementType.CHAR_NOT:
default:
output.Append("] ");
break;
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index db0ac179c..abd8f879f 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -43,7 +43,7 @@ public sealed class LLamaContext
///
/// The context params set for this context
///
- public IContextParams Params { get; set; }
+ public IContextParams Params { get; }
///
/// The native handle, which is used to be passed to the native APIs
@@ -56,15 +56,6 @@ public sealed class LLamaContext
///
public Encoding Encoding { get; }
- internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
- {
- Params = @params;
-
- _logger = logger;
- Encoding = @params.Encoding;
- NativeHandle = nativeContext;
- }
-
///
/// Create a new LLamaContext for the given LLamaWeights
///
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index ab56280c3..c551016c0 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -12,17 +12,15 @@ namespace LLama
public sealed class LLamaEmbedder
: IDisposable
{
- private readonly LLamaContext _ctx;
-
///
/// Dimension of embedding vectors
///
- public int EmbeddingSize => _ctx.EmbeddingSize;
+ public int EmbeddingSize => Context.EmbeddingSize;
///
/// LLama Context
///
- public LLamaContext Context => this._ctx;
+ public LLamaContext Context { get; }
///
/// Create a new embedder, using the given LLamaWeights
@@ -33,7 +31,7 @@ public sealed class LLamaEmbedder
public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
@params.EmbeddingMode = true;
- _ctx = weights.CreateContext(@params, logger);
+ Context = weights.CreateContext(@params, logger);
}
///
@@ -72,20 +70,20 @@ public float[] GetEmbeddings(string text)
///
public float[] GetEmbeddings(string text, bool addBos)
{
- var embed_inp_array = _ctx.Tokenize(text, addBos);
+ var embed_inp_array = Context.Tokenize(text, addBos);
// TODO(Rinne): deal with log of prompt
if (embed_inp_array.Length > 0)
- _ctx.Eval(embed_inp_array, 0);
+ Context.Eval(embed_inp_array, 0);
unsafe
{
- var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
+ var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
return Array.Empty();
- return new Span(embeddings, EmbeddingSize).ToArray();
+ return embeddings.ToArray();
}
}
@@ -94,7 +92,7 @@ public float[] GetEmbeddings(string text, bool addBos)
///
public void Dispose()
{
- _ctx.Dispose();
+ Context.Dispose();
}
}
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index 847be5515..5cb482add 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -64,7 +64,7 @@ public sealed class LLamaWeights
///
public IReadOnlyDictionary Metadata { get; set; }
- internal LLamaWeights(SafeLlamaModelHandle weights)
+ private LLamaWeights(SafeLlamaModelHandle weights)
{
NativeHandle = weights;
Metadata = weights.ReadMetadata();
diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs
index ea1c1172c..65fbccba3 100644
--- a/LLama/Native/LLamaKvCacheView.cs
+++ b/LLama/Native/LLamaKvCacheView.cs
@@ -14,7 +14,7 @@ public struct LLamaKvCacheViewCell
/// May be negative if the cell is not populated.
///
public LLamaPos pos;
-};
+}
///
/// An updateable view of the KV cache (llama_kv_cache_view)
@@ -130,7 +130,7 @@ public ref LLamaKvCacheView GetView()
}
}
-partial class NativeApi
+public static partial class NativeApi
{
///
/// Create an empty KV cache view. (use only for debugging purposes)
diff --git a/LLama/Native/LLamaPos.cs b/LLama/Native/LLamaPos.cs
index 67ede7d52..52d67d505 100644
--- a/LLama/Native/LLamaPos.cs
+++ b/LLama/Native/LLamaPos.cs
@@ -6,7 +6,7 @@ namespace LLama.Native;
/// Indicates position in a sequence
///
[StructLayout(LayoutKind.Sequential)]
-public struct LLamaPos
+public record struct LLamaPos
{
///
/// The raw value
@@ -17,7 +17,7 @@ public struct LLamaPos
/// Create a new LLamaPos
///
///
- public LLamaPos(int value)
+ private LLamaPos(int value)
{
Value = value;
}
diff --git a/LLama/Native/LLamaSeqId.cs b/LLama/Native/LLamaSeqId.cs
index 191a6b5ec..bcee74f13 100644
--- a/LLama/Native/LLamaSeqId.cs
+++ b/LLama/Native/LLamaSeqId.cs
@@ -6,7 +6,7 @@ namespace LLama.Native;
/// ID for a sequence in a batch
///
[StructLayout(LayoutKind.Sequential)]
-public struct LLamaSeqId
+public record struct LLamaSeqId
{
///
/// The raw value
@@ -17,7 +17,7 @@ public struct LLamaSeqId
/// Create a new LLamaSeqId
///
///
- public LLamaSeqId(int value)
+ private LLamaSeqId(int value)
{
Value = value;
}
diff --git a/LLama/Native/LLamaTokenType.cs b/LLama/Native/LLamaTokenType.cs
new file mode 100644
index 000000000..171e782ae
--- /dev/null
+++ b/LLama/Native/LLamaTokenType.cs
@@ -0,0 +1,12 @@
+namespace LLama.Native;
+
+public enum LLamaTokenType
+{
+ LLAMA_TOKEN_TYPE_UNDEFINED = 0,
+ LLAMA_TOKEN_TYPE_NORMAL = 1,
+ LLAMA_TOKEN_TYPE_UNKNOWN = 2,
+ LLAMA_TOKEN_TYPE_CONTROL = 3,
+ LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
+ LLAMA_TOKEN_TYPE_UNUSED = 5,
+ LLAMA_TOKEN_TYPE_BYTE = 6,
+}
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.BeamSearch.cs b/LLama/Native/NativeApi.BeamSearch.cs
index 1049dbe3a..142b997bb 100644
--- a/LLama/Native/NativeApi.BeamSearch.cs
+++ b/LLama/Native/NativeApi.BeamSearch.cs
@@ -3,7 +3,7 @@
namespace LLama.Native;
-public partial class NativeApi
+public static partial class NativeApi
{
///
/// Type of pointer to the beam_search_callback function.
diff --git a/LLama/Native/NativeApi.Grammar.cs b/LLama/Native/NativeApi.Grammar.cs
index 84e298c7d..4d47872b5 100644
--- a/LLama/Native/NativeApi.Grammar.cs
+++ b/LLama/Native/NativeApi.Grammar.cs
@@ -5,7 +5,7 @@ namespace LLama.Native
{
using llama_token = Int32;
- public unsafe partial class NativeApi
+ public static partial class NativeApi
{
///
/// Create a new grammar from the given set of grammar rules
@@ -15,7 +15,7 @@ public unsafe partial class NativeApi
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);
+ public static extern unsafe IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);
///
/// Free all memory from the given SafeLLamaGrammarHandle
diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs
index d8a887252..5ae02c1ac 100644
--- a/LLama/Native/NativeApi.Load.cs
+++ b/LLama/Native/NativeApi.Load.cs
@@ -4,13 +4,12 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
-using System.Linq;
using System.Runtime.InteropServices;
using System.Text.Json;
namespace LLama.Native
{
- public partial class NativeApi
+ public static partial class NativeApi
{
static NativeApi()
{
@@ -97,22 +96,13 @@ private static int GetCudaMajorVersion()
}
if (string.IsNullOrEmpty(version))
- {
return -1;
- }
- else
- {
- version = version.Split('.')[0];
- bool success = int.TryParse(version, out var majorVersion);
- if (success)
- {
- return majorVersion;
- }
- else
- {
- return -1;
- }
- }
+
+ version = version.Split('.')[0];
+ if (int.TryParse(version, out var majorVersion))
+ return majorVersion;
+
+ return -1;
}
private static string GetCudaVersionFromPath(string cudaPath)
@@ -129,7 +119,7 @@ private static string GetCudaVersionFromPath(string cudaPath)
{
return string.Empty;
}
- return versionNode.GetString();
+ return versionNode.GetString() ?? "";
}
}
catch (Exception)
@@ -169,18 +159,14 @@ private static List GetLibraryTryOrder(NativeLibraryConfig.Description c
{
platform = OSPlatform.OSX;
suffix = ".dylib";
- if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported)
- {
- prefix = "runtimes/osx-arm64/native/";
- }
- else
- {
- prefix = "runtimes/osx-x64/native/";
- }
+
+ prefix = System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported
+ ? "runtimes/osx-arm64/native/"
+ : "runtimes/osx-x64/native/";
}
else
{
- throw new RuntimeError($"Your system plarform is not supported, please open an issue in LLamaSharp.");
+ throw new RuntimeError("Your system platform is not supported, please open an issue in LLamaSharp.");
}
Log($"Detected OS Platform: {platform}", LogLevel.Information);
@@ -275,15 +261,15 @@ private static IntPtr TryLoadLibrary()
var libraryTryLoadOrder = GetLibraryTryOrder(configuration);
- string[] preferredPaths = configuration.SearchDirectories;
- string[] possiblePathPrefix = new string[] {
- System.AppDomain.CurrentDomain.BaseDirectory,
+ var preferredPaths = configuration.SearchDirectories;
+ var possiblePathPrefix = new[] {
+ AppDomain.CurrentDomain.BaseDirectory,
Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
};
- var tryFindPath = (string filename) =>
+ string TryFindPath(string filename)
{
- foreach(var path in preferredPaths)
+ foreach (var path in preferredPaths)
{
if (File.Exists(Path.Combine(path, filename)))
{
@@ -291,7 +277,7 @@ private static IntPtr TryLoadLibrary()
}
}
- foreach(var path in possiblePathPrefix)
+ foreach (var path in possiblePathPrefix)
{
if (File.Exists(Path.Combine(path, filename)))
{
@@ -300,21 +286,19 @@ private static IntPtr TryLoadLibrary()
}
return filename;
- };
+ }
foreach (var libraryPath in libraryTryLoadOrder)
{
- var fullPath = tryFindPath(libraryPath);
+ var fullPath = TryFindPath(libraryPath);
var result = TryLoad(fullPath, true);
if (result is not null && result != IntPtr.Zero)
{
Log($"{fullPath} is selected and loaded successfully.", LogLevel.Information);
- return result ?? IntPtr.Zero;
- }
- else
- {
- Log($"Tried to load {fullPath} but failed.", LogLevel.Information);
+ return (IntPtr)result;
}
+
+ Log($"Tried to load {fullPath} but failed.", LogLevel.Information);
}
if (!configuration.AllowFallback)
diff --git a/LLama/Native/NativeApi.Quantize.cs b/LLama/Native/NativeApi.Quantize.cs
index d4ff5cf80..b849e38d5 100644
--- a/LLama/Native/NativeApi.Quantize.cs
+++ b/LLama/Native/NativeApi.Quantize.cs
@@ -2,7 +2,7 @@
namespace LLama.Native
{
- public partial class NativeApi
+ public static partial class NativeApi
{
///
/// Returns 0 on success
diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs
index 9e7d375b6..53a6dd233 100644
--- a/LLama/Native/NativeApi.Sampling.cs
+++ b/LLama/Native/NativeApi.Sampling.cs
@@ -5,7 +5,7 @@ namespace LLama.Native
{
using llama_token = Int32;
- public unsafe partial class NativeApi
+ public static partial class NativeApi
{
///
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -19,7 +19,7 @@ public unsafe partial class NativeApi
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
+ public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
ref LLamaTokenDataArrayNative candidates,
llama_token* last_tokens, ulong last_tokens_size,
float penalty_repeat,
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 24b9f571d..1c7715f66 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -9,17 +9,6 @@ namespace LLama.Native
{
using llama_token = Int32;
- public enum LLamaTokenType
- {
- LLAMA_TOKEN_TYPE_UNDEFINED = 0,
- LLAMA_TOKEN_TYPE_NORMAL = 1,
- LLAMA_TOKEN_TYPE_UNKNOWN = 2,
- LLAMA_TOKEN_TYPE_CONTROL = 3,
- LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
- LLAMA_TOKEN_TYPE_UNUSED = 5,
- LLAMA_TOKEN_TYPE_BYTE = 6,
- }
-
///
/// Callback from llama.cpp with log messages
///
@@ -30,7 +19,7 @@ public enum LLamaTokenType
///
/// Direct translation of the llama.cpp API
///
- public unsafe partial class NativeApi
+ public static partial class NativeApi
{
///
/// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded.
@@ -165,7 +154,7 @@ public unsafe partial class NativeApi
///
/// the number of bytes copied
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest);
+ public static extern unsafe ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest);
///
/// Set the state reading from the specified address
@@ -174,7 +163,7 @@ public unsafe partial class NativeApi
///
/// the number of bytes read
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src);
+ public static extern unsafe ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src);
///
/// Load session file
@@ -186,7 +175,7 @@ public unsafe partial class NativeApi
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out);
+ public static extern unsafe bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out);
///
/// Save session file
@@ -211,7 +200,7 @@ public unsafe partial class NativeApi
/// Returns 0 on success
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[Obsolete("use llama_decode() instead")]
- public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past);
+ public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past);
///
/// Convert the provided text into tokens.
@@ -228,34 +217,37 @@ public unsafe partial class NativeApi
///
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special)
{
- // Calculate number of bytes in text and borrow an array that large (+1 for nul byte)
- var byteCount = encoding.GetByteCount(text);
- var array = ArrayPool.Shared.Rent(byteCount + 1);
- try
+ unsafe
{
- // Convert to bytes
- fixed (char* textPtr = text)
- fixed (byte* arrayPtr = array)
+ // Calculate number of bytes in text and borrow an array that large (+1 for nul byte)
+ var byteCount = encoding.GetByteCount(text);
+ var array = ArrayPool.Shared.Rent(byteCount + 1);
+ try
{
- encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length);
+ // Convert to bytes
+ fixed (char* textPtr = text)
+ fixed (byte* arrayPtr = array)
+ {
+ encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length);
+ }
+
+ // Add a zero byte to the end to terminate the string
+ array[byteCount] = 0;
+
+ // Do the actual tokenization
+ fixed (byte* arrayPtr = array)
+ fixed (llama_token* tokensPtr = tokens)
+ return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(array);
}
-
- // Add a zero byte to the end to terminate the string
- array[byteCount] = 0;
-
- // Do the actual tokenization
- fixed (byte* arrayPtr = array)
- fixed (llama_token* tokensPtr = tokens)
- return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
- }
- finally
- {
- ArrayPool.Shared.Return(array);
}
}
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token);
+ public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token);
@@ -281,7 +273,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern float* llama_get_logits(SafeLLamaContextHandle ctx);
+ public static extern unsafe float* llama_get_logits(SafeLLamaContextHandle ctx);
///
/// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
@@ -290,16 +282,24 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
+ public static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
///
/// Get the embeddings for the input
- /// shape: [n_embd] (1-dimensional)
///
///
///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx);
+ public static Span llama_get_embeddings(SafeLLamaContextHandle ctx)
+ {
+ unsafe
+ {
+ var ptr = llama_get_embeddings_native(ctx);
+ return new Span(ptr, ctx.EmbeddingSize);
+ }
+
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")]
+ static extern unsafe float* llama_get_embeddings_native(SafeLLamaContextHandle ctx);
+ }
///
/// Get the "Beginning of sentence" token
@@ -426,7 +426,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
///
/// The length of the string on success, or -1 on failure
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
+ public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
///
/// Get the number of metadata key/value pairs
@@ -445,7 +445,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
///
/// The length of the string on success, or -1 on failure
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
+ public static extern unsafe int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
///
/// Get metadata value as a string by index
@@ -456,7 +456,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
///
/// The length of the string on success, or -1 on failure
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
+ public static extern unsafe int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
///
/// Get a string describing the model type
@@ -466,7 +466,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
///
/// The length of the string on success, or -1 on failure
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);
+ public static extern unsafe int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);
///
/// Get the size of the model in bytes
@@ -493,7 +493,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
/// size of the buffer
/// The length written, or if the buffer is too small a negative that indicates the length required
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
+ public static extern unsafe int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
///
/// Convert text into tokens
@@ -509,7 +509,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
/// Returns a negative number on failure - the number of tokens that would have been returned
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special);
+ public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special);
///
/// Register a callback to receive llama log messages
diff --git a/LLama/Native/NativeLibraryConfig.cs b/LLama/Native/NativeLibraryConfig.cs
index aaa328c91..ad52fc816 100644
--- a/LLama/Native/NativeLibraryConfig.cs
+++ b/LLama/Native/NativeLibraryConfig.cs
@@ -29,10 +29,11 @@ public sealed class NativeLibraryConfig
private bool _allowFallback = true;
private bool _skipCheck = false;
private bool _logging = false;
+
///
/// search directory -> priority level, 0 is the lowest.
///
- private List _searchDirectories = new List();
+ private readonly List _searchDirectories = new List();
private static void ThrowIfLoaded()
{
@@ -159,9 +160,8 @@ public NativeLibraryConfig WithSearchDirectory(string directory)
internal static Description CheckAndGatherDescription()
{
if (Instance._allowFallback && Instance._skipCheck)
- {
throw new ArgumentException("Cannot skip the check when fallback is allowed.");
- }
+
return new Description(
Instance._libraryPath,
Instance._useCuda,
@@ -169,7 +169,8 @@ internal static Description CheckAndGatherDescription()
Instance._allowFallback,
Instance._skipCheck,
Instance._logging,
- Instance._searchDirectories.Concat(new string[] { "./" }).ToArray());
+ Instance._searchDirectories.Concat(new[] { "./" }).ToArray()
+ );
}
internal static string AvxLevelToString(AvxLevel level)
@@ -204,7 +205,9 @@ private static bool CheckAVX512()
if (!System.Runtime.Intrinsics.X86.X86Base.IsSupported)
return false;
+ // ReSharper disable UnusedVariable (ebx is used when < NET8)
var (_, ebx, ecx, _) = System.Runtime.Intrinsics.X86.X86Base.CpuId(7, 0);
+ // ReSharper restore UnusedVariable
var vnni = (ecx & 0b_1000_0000_0000) != 0;
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index df33076f9..98b510783 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,6 +1,5 @@
using System;
using System.Buffers;
-using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
@@ -51,8 +50,6 @@ public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
_model.DangerousAddRef(ref success);
if (!success)
throw new RuntimeError("Failed to increment model refcount");
-
-
}
///
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index d62e50417..2280250ec 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -214,7 +214,6 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
/// Get the metadata key for the given index
///
/// The index to get
- /// A temporary buffer to store key characters in. Must be large enough to contain the key.
/// The key, null if there is no such key or if the buffer was too small
public Memory? MetadataKeyByIndex(int index)
{
@@ -243,7 +242,6 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
/// Get the metadata value for the given index
///
/// The index to get
- /// A temporary buffer to store value characters in. Must be large enough to contain the value.
/// The value, null if there is no such value or if the buffer was too small
public Memory? MetadataValueByIndex(int index)
{