From b893c6f6091313fc4db3182c206f2246778588f6 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Thu, 9 Nov 2023 03:27:21 +0800 Subject: [PATCH] feat: add detection template for cuda and avx. --- LLama/LLamaSharp.csproj | 1 + LLama/Native/NativeApi.cs | 155 ++++++++++++++++++++++++++++---------- 2 files changed, 115 insertions(+), 41 deletions(-) diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index d525202f2..4577515b1 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -47,6 +47,7 @@ + diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index e3b182bd4..9998020ef 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -1,8 +1,13 @@ using System; using System.Buffers; +using System.Reflection; using System.Runtime.InteropServices; using System.Text; using LLama.Exceptions; +using ManagedCuda; +#if NET6_0_OR_GREATER +using System.Runtime.Intrinsics.X86; +#endif #pragma warning disable IDE1006 // Naming Styles @@ -24,8 +29,9 @@ public unsafe partial class NativeApi { static NativeApi() { - // Try to load a preferred library, based on CPU feature detection - TryLoadLibrary(); +#if NET6_0_OR_GREATER + NativeLibrary.SetDllImportResolver(typeof(NativeApi).Assembly, LLamaImportResolver); +#endif try { @@ -44,61 +50,120 @@ static NativeApi() } /// - /// Try to load libllama, using CPU feature detection to try and load a more specialised DLL if possible + /// Get the cuda version if possible. /// - /// The library handle to unload later, or IntPtr.Zero if no library was loaded - private static IntPtr TryLoadLibrary() + /// -1 for no cuda + private static int GetCudaVersion() { -#if NET6_0_OR_GREATER - - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + int deviceCount = CudaContext.GetDeviceCount(); + for (int deviceIndex = 0; deviceIndex < deviceCount; deviceIndex++) { - // All of the Windows libraries, in order of preference - return TryLoad("cu12.1.0/libllama.dll") - ?? 
TryLoad("cu11.7.1/libllama.dll") -#if NET8_0_OR_GREATER - ?? TryLoad("avx512/libllama.dll", System.Runtime.Intrinsics.X86.Avx512.IsSupported) -#endif - ?? TryLoad("avx2/libllama.dll", System.Runtime.Intrinsics.X86.Avx2.IsSupported) - ?? TryLoad("avx/libllama.dll", System.Runtime.Intrinsics.X86.Avx.IsSupported) - ?? IntPtr.Zero; + using (CudaContext ctx = new CudaContext(deviceIndex)) + { + var version = ctx.GetAPIVersionOfCurrentContext(); + return version.Major; + } } + return -1; + } - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - // All of the Linux libraries, in order of preference - return TryLoad("cu12.1.0/libllama.so") - ?? TryLoad("cu11.7.1/libllama.so") + /// + /// Get the AVX flag for native library name. + /// + /// + private static string GetAvxFlag() + { + AvxLevel level = AvxLevel.None; +#if NET6_0_OR_GREATER + if (Avx.IsSupported) level = AvxLevel.Avx; + if (Avx2.IsSupported) level = AvxLevel.Avx2; #if NET8_0_OR_GREATER - ?? TryLoad("avx512/libllama.so", System.Runtime.Intrinsics.X86.Avx512.IsSupported) + if(Avx512F.IsSupported) level = AvxLevel.Avx512; #endif - ?? TryLoad("avx2/libllama.so", System.Runtime.Intrinsics.X86.Avx2.IsSupported) - ?? TryLoad("avx/libllama.so", System.Runtime.Intrinsics.X86.Avx.IsSupported) - ?? IntPtr.Zero; - } - if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + return level switch { - return IntPtr.Zero; - } + AvxLevel.None => "", + AvxLevel.Avx => "-avx", + AvxLevel.Avx2 => "-avx2", + AvxLevel.Avx512 => "-avx512", + }; +#else + return string.Empty; #endif - - return IntPtr.Zero; + } #if NET6_0_OR_GREATER - // Try to load a DLL from the path if supported. Returns null if nothing is loaded. - static IntPtr? TryLoad(string path, bool supported = true) + private static IntPtr LLamaImportResolver(string name, Assembly assembly, DllImportSearchPath? 
searchPath) + { + IntPtr handle = IntPtr.Zero; + if(!name.Equals(libraryName)) { - if (!supported) - return null; - - if (NativeLibrary.TryLoad(path, out var handle)) - return handle; + return NativeLibrary.Load(name, assembly, searchPath); + } - return null; + string libraryPath = string.Empty; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var avxFlag = GetAvxFlag(); + // check cuda + var cudaVersion = GetCudaVersion(); + if(cudaVersion == 11) + { + libraryPath = $"runtimes/win-x64/native/libllama-cuda11{avxFlag}.dll"; + } + else if (cudaVersion == 12) + { + libraryPath = $"runtimes/win-x64/native/libllama-cuda12{avxFlag}.dll"; + } + else if(cudaVersion == -1) // cpu version + { + libraryPath = $"runtimes/win-x64/native/libllama{avxFlag}.dll"; + } + else + { + throw new NotImplementedException($"Cuda version {cudaVersion} has not been supported, please compile dll yourself or open an issue in LLamaSharp."); + } } -#endif + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + var avxFlag = GetAvxFlag(); + // check cuda + var cudaVersion = GetCudaVersion(); + if (cudaVersion == 11) + { + libraryPath = $"runtimes/linux-x64/native/libllama-cuda11{avxFlag}.so"; + } + else if (cudaVersion == 12) + { + libraryPath = $"runtimes/linux-x64/native/libllama-cuda12{avxFlag}.so"; + } + else if (cudaVersion == -1) // cpu version + { + libraryPath = $"runtimes/linux-x64/native/libllama{avxFlag}.so"; + } + else + { + throw new NotImplementedException($"Cuda version {cudaVersion} has not been supported, please compile dll yourself or open an issue in LLamaSharp."); + } + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported) + { + libraryPath = $"runtimes/osx-arm64/native/libllama.dylib"; + } + else + { + libraryPath = $"runtimes/osx-x64/native/libllama.dylib"; + } + } + + NativeLibrary.TryLoad(libraryPath, assembly, searchPath, out handle); + return handle; } +#endif + private 
const string libraryName = "libllama"; @@ -572,5 +637,13 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch); + + private enum AvxLevel + { + None = 0, + Avx = 1, + Avx2 = 2, + Avx512 = 3 + } } }