Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LLama.Examples/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

Console.WriteLine("======================================================================================================");

NativeLibraryConfig.Default.WithCuda().WithLogs();
NativeLibraryConfig.Instance.WithCuda().WithLogs();

NativeApi.llama_empty_call();
Console.WriteLine();
Expand Down
24 changes: 7 additions & 17 deletions LLama/Native/NativeApi.Load.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Diagnostics;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;

namespace LLama.Native
Expand Down Expand Up @@ -227,24 +226,15 @@ private static List<string> GetLibraryTryOrder(NativeLibraryConfig.Description c
}
else if (platform != OSPlatform.OSX) // in macos there's absolutely no avx
{
#if NET8_0_OR_GREATER
if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx512)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx512, prefix, suffix)));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix)));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix)));
}
else
#endif
if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx2)
{
if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx512)
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx512, prefix, suffix));

if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx2)
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix));

if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx)
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix));
}
else if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix));
}

result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.None, prefix, suffix));
}

Expand Down
156 changes: 73 additions & 83 deletions LLama/Native/NativeLibraryConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,46 @@ namespace LLama.Native
{
#if NET6_0_OR_GREATER
/// <summary>
/// A class about configurations when loading native libraries.
/// Note that it could be configured only once before any call to llama model apis.
/// Allows configuration of the native llama.cpp libraries to load and use.
/// All configuration must be done before using **any** other LLamaSharp methods!
/// </summary>
public class NativeLibraryConfig
public sealed class NativeLibraryConfig
{
private static NativeLibraryConfig? instance;
private static readonly object lockObject = new object();
public static NativeLibraryConfig Default
{
get
{
return GetInstance();
}
}
private static readonly Lazy<NativeLibraryConfig> _instance = new(() => new NativeLibraryConfig());

/// <summary>
/// Get the config instance
/// </summary>
public static NativeLibraryConfig Instance => _instance.Value;

/// <summary>
/// Whether the native library has already been loaded. Once this is true, the configuration can no longer be changed.
/// </summary>
public static bool LibraryHasLoaded { get; internal set; } = false;

private string _libraryPath;
private bool _useCuda;
private string _libraryPath = string.Empty;
private bool _useCuda = true;
private AvxLevel _avxLevel;
private bool _allowFallback;
private bool _skipCheck;
private bool _logging;
private bool _allowFallback = true;
private bool _skipCheck = false;
private bool _logging = false;

internal static NativeLibraryConfig GetInstance()
private static void ThrowIfLoaded()
{
if (instance is null)
{
lock (lockObject)
{
if (instance is null)
{
instance = new NativeLibraryConfig();
}
}
}
return instance;
if (LibraryHasLoaded)
throw new InvalidOperationException("NativeLibraryConfig must be configured before using **any** other LLamaSharp methods!");
}

/// <summary>
/// Load a specified native library as backend for LLamaSharp.
/// When this method is called, all the other configurations will be ignored.
/// </summary>
/// <param name="libraryPath"></param>
/// <exception cref="InvalidOperationException"></exception>
/// <exception cref="InvalidOperationException">Thrown if <see cref="LibraryHasLoaded"/> is true.</exception>
public NativeLibraryConfig WithLibrary(string libraryPath)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
ThrowIfLoaded();

_libraryPath = libraryPath;
return this;
}
Expand All @@ -67,13 +53,11 @@ public NativeLibraryConfig WithLibrary(string libraryPath)
/// </summary>
/// <param name="enable"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
/// <exception cref="InvalidOperationException">Thrown if <see cref="LibraryHasLoaded"/> is true.</exception>
public NativeLibraryConfig WithCuda(bool enable = true)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
ThrowIfLoaded();

_useCuda = enable;
return this;
}
Expand All @@ -83,29 +67,25 @@ public NativeLibraryConfig WithCuda(bool enable = true)
/// </summary>
/// <param name="level"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
/// <exception cref="InvalidOperationException">Thrown if <see cref="LibraryHasLoaded"/> is true.</exception>
public NativeLibraryConfig WithAvx(AvxLevel level)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
ThrowIfLoaded();

_avxLevel = level;
return this;
}

/// <summary>
/// Configure whether to allow fallback when there's not match for preffered settings.
/// Configure whether to allow fallback when there's no match for preferred settings.
/// </summary>
/// <param name="enable"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
/// <exception cref="InvalidOperationException">Thrown if <see cref="LibraryHasLoaded"/> is true.</exception>
public NativeLibraryConfig WithAutoFallback(bool enable = true)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
ThrowIfLoaded();

_allowFallback = enable;
return this;
}
Expand All @@ -117,13 +97,11 @@ public NativeLibraryConfig WithAutoFallback(bool enable = true)
/// </summary>
/// <param name="enable"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
/// <exception cref="InvalidOperationException">Thrown if <see cref="LibraryHasLoaded"/> is true.</exception>
public NativeLibraryConfig SkipCheck(bool enable = true)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
ThrowIfLoaded();

_skipCheck = enable;
return this;
}
Expand All @@ -133,24 +111,22 @@ public NativeLibraryConfig SkipCheck(bool enable = true)
/// </summary>
/// <param name="enable"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
/// <exception cref="InvalidOperationException">Thrown if <see cref="LibraryHasLoaded"/> is true.</exception>
public NativeLibraryConfig WithLogs(bool enable = true)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
ThrowIfLoaded();

_logging = enable;
return this;
}

internal static Description CheckAndGatherDescription()
{
if (Default._allowFallback && Default._skipCheck)
if (Instance._allowFallback && Instance._skipCheck)
{
throw new ArgumentException("Cannot skip the check when fallback is allowed.");
}
return new Description(Default._libraryPath, Default._useCuda, Default._avxLevel, Default._allowFallback, Default._skipCheck, Default._logging);
return new Description(Instance._libraryPath, Instance._useCuda, Instance._avxLevel, Instance._allowFallback, Instance._skipCheck, Instance._logging);
}

internal static string AvxLevelToString(AvxLevel level)
Expand All @@ -163,39 +139,53 @@ internal static string AvxLevelToString(AvxLevel level)
#if NET8_0_OR_GREATER
AvxLevel.Avx512 => "avx512"
#endif
_ => throw new ArgumentException($"Cannot recognize Avx level {level}")
_ => throw new ArgumentException($"Unknown AvxLevel '{level}'")
};
}


/// <summary>
/// Private constructor prevents new instances of this class being created
/// </summary>
private NativeLibraryConfig()
{
_libraryPath = string.Empty;
_useCuda = true;
_avxLevel = AvxLevel.Avx2;
_allowFallback = true;
_skipCheck = false;
_logging = false;
// Automatically detect the highest supported AVX level
if (System.Runtime.Intrinsics.X86.Avx.IsSupported)
_avxLevel = AvxLevel.Avx;
if (System.Runtime.Intrinsics.X86.Avx2.IsSupported)
_avxLevel = AvxLevel.Avx2;
#if NET8_0_OR_GREATER
if (System.Runtime.Intrinsics.X86.Avx512.IsSupported)
_avxLevel = AvxLevel.Avx512;
#endif
}

/// <summary>
/// Avx support configuration
/// </summary>
public enum AvxLevel
{
/// <inheritdoc />
None = 0,
/// <inheritdoc />
Avx = 1,
/// <inheritdoc />
Avx2 = 2,
#if NET8_0_OR_GREATER
/// <inheritdoc />
Avx512 = 3,
#endif
/// <summary>
/// No AVX
/// </summary>
None,

/// <summary>
/// Advanced Vector Extensions (supported by most processors after 2011)
/// </summary>
Avx,

/// <summary>
/// AVX2 (supported by most processors after 2013)
/// </summary>
Avx2,

/// <summary>
/// AVX512 (supported by some processors after 2016, not widely supported)
/// </summary>
Avx512,
}
internal record Description(string Path = "", bool UseCuda = true, AvxLevel AvxLevel = AvxLevel.Avx2,
bool AllowFallback = true, bool SkipCheck = false, bool Logging = false);

internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging);
}
#endif
}
}