Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions LLama/LLamaSharp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="ManagedCuda" Version="10.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="7.0.1" />
</ItemGroup>

Expand Down
155 changes: 114 additions & 41 deletions LLama/Native/NativeApi.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
using System;
using System.Buffers;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
using ManagedCuda;
#if NET6_0_OR_GREATER
using System.Runtime.Intrinsics.X86;
#endif

#pragma warning disable IDE1006 // Naming Styles

Expand All @@ -24,8 +29,9 @@ public unsafe partial class NativeApi
{
static NativeApi()
{
// Try to load a preferred library, based on CPU feature detection
TryLoadLibrary();
#if NET6_0_OR_GREATER
NativeLibrary.SetDllImportResolver(typeof(NativeApi).Assembly, LLamaImportResolver);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deliberately avoided using NativeLibrary.SetDllImportResolver when I originally set this up, since it overwrites global configuration that may already be in use by the application.

#endif

try
{
Expand All @@ -44,61 +50,120 @@ static NativeApi()
}

/// <summary>
/// Try to load libllama, using CPU feature detection to try and load a more specialised DLL if possible
/// Get the cuda version if possible.
/// </summary>
/// <returns>The library handle to unload later, or IntPtr.Zero if no library was loaded</returns>
private static IntPtr TryLoadLibrary()
/// <returns>The major CUDA API version reported by the first CUDA device context, or -1 if no CUDA device is found.</returns>
private static int GetCudaVersion()
{
#if NET6_0_OR_GREATER

if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
int deviceCount = CudaContext.GetDeviceCount();
for (int deviceIndex = 0; deviceIndex < deviceCount; deviceIndex++)
{
// All of the Windows libraries, in order of preference
return TryLoad("cu12.1.0/libllama.dll")
?? TryLoad("cu11.7.1/libllama.dll")
#if NET8_0_OR_GREATER
?? TryLoad("avx512/libllama.dll", System.Runtime.Intrinsics.X86.Avx512.IsSupported)
#endif
?? TryLoad("avx2/libllama.dll", System.Runtime.Intrinsics.X86.Avx2.IsSupported)
?? TryLoad("avx/libllama.dll", System.Runtime.Intrinsics.X86.Avx.IsSupported)
?? IntPtr.Zero;
using (CudaContext ctx = new CudaContext(deviceIndex))
{
var version = ctx.GetAPIVersionOfCurrentContext();
return version.Major;
}
}
return -1;
}

if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
// All of the Linux libraries, in order of preference
return TryLoad("cu12.1.0/libllama.so")
?? TryLoad("cu11.7.1/libllama.so")
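/// <summary>
/// Get the AVX suffix used in the native library file name.
/// </summary>
/// <returns>"-avx", "-avx2" or "-avx512" for the highest supported AVX level, or an empty string when no AVX support is detected.</returns>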
private static string GetAvxFlag()
{
AvxLevel level = AvxLevel.None;
#if NET6_0_OR_GREATER
if (Avx.IsSupported) level = AvxLevel.Avx;
if (Avx2.IsSupported) level = AvxLevel.Avx2;
#if NET8_0_OR_GREATER
?? TryLoad("avx512/libllama.so", System.Runtime.Intrinsics.X86.Avx512.IsSupported)
if(Avx512F.IsSupported) level = AvxLevel.Avx512;
#endif
?? TryLoad("avx2/libllama.so", System.Runtime.Intrinsics.X86.Avx2.IsSupported)
?? TryLoad("avx/libllama.so", System.Runtime.Intrinsics.X86.Avx.IsSupported)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AVX/AVX2/AVX512 detection should already be handled by this

?? IntPtr.Zero;
}

if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
return level switch
{
return IntPtr.Zero;
}
AvxLevel.None => "",
AvxLevel.Avx => "-avx",
AvxLevel.Avx2 => "-avx2",
AvxLevel.Avx512 => "-avx512",
};
#else
return string.Empty;
#endif

return IntPtr.Zero;
}

#if NET6_0_OR_GREATER
// Try to load a DLL from the path if supported. Returns null if nothing is loaded.
static IntPtr? TryLoad(string path, bool supported = true)
private static IntPtr LLamaImportResolver(string name, Assembly assembly, DllImportSearchPath? searchPath)
{
IntPtr handle = IntPtr.Zero;
if(!name.Equals(libraryName))
{
if (!supported)
return null;

if (NativeLibrary.TryLoad(path, out var handle))
return handle;
return NativeLibrary.Load(name, assembly, searchPath);
}

return null;
string libraryPath = string.Empty;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
var avxFlag = GetAvxFlag();
// check cuda
var cudaVersion = GetCudaVersion();
if(cudaVersion == 11)
{
libraryPath = $"runtimes/win-x64/native/libllama-cuda11{avxFlag}.dll";
}
else if (cudaVersion == 12)
{
libraryPath = $"runtimes/win-x64/native/libllama-cuda12{avxFlag}.dll";
}
else if(cudaVersion == -1) // cpu version
{
libraryPath = $"runtimes/win-x64/native/libllama{avxFlag}.dll";
}
else
{
throw new NotImplementedException($"Cuda version {cudaVersion} has not been supported, please compile dll yourself or open an issue in LLamaSharp.");
}
}
#endif
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
var avxFlag = GetAvxFlag();
// check cuda
var cudaVersion = GetCudaVersion();
if (cudaVersion == 11)
{
libraryPath = $"runtimes/linux-x64/native/libllama-cuda11{avxFlag}.so";
}
else if (cudaVersion == 12)
{
libraryPath = $"runtimes/linux-x64/native/libllama-cuda12{avxFlag}.so";
}
else if (cudaVersion == -1) // cpu version
{
libraryPath = $"runtimes/linux-x64/native/libllama{avxFlag}.so";
}
else
{
throw new NotImplementedException($"Cuda version {cudaVersion} has not been supported, please compile dll yourself or open an issue in LLamaSharp.");
}
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported)
{
libraryPath = $"runtimes/osx-arm64/native/libllama.dylib";
}
else
{
libraryPath = $"runtimes/osx-x64/native/libllama.dylib";
}
}

NativeLibrary.TryLoad(libraryPath, assembly, searchPath, out handle);
return handle;
}
#endif


private const string libraryName = "libllama";

Expand Down Expand Up @@ -572,5 +637,13 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);

/// <summary>
/// AVX instruction-set levels used to choose which native library variant to load.
/// Members are declared in ascending order, so the implicit values are 0 through 3.
/// </summary>
private enum AvxLevel
{
    /// <summary>No AVX support detected.</summary>
    None,

    /// <summary>AVX is supported.</summary>
    Avx,

    /// <summary>AVX2 is supported.</summary>
    Avx2,

    /// <summary>AVX-512 is supported.</summary>
    Avx512,
}
}
}