From 1b13f7c7178339785f2f34945efcb5cb59942033 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 16 Dec 2023 15:46:28 +0000 Subject: [PATCH] Improved support for AVX512: - Enabled more features in build process (VBMI and VNNI) - Added runtime checking for these features - Improved runtime checking to no longer require dotnet8.0 --- .github/workflows/compile.yml | 2 +- LLama/Native/NativeLibraryConfig.cs | 27 ++++++++++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml index a8169be12..10b7ff6c7 100644 --- a/.github/workflows/compile.yml +++ b/.github/workflows/compile.yml @@ -57,7 +57,7 @@ jobs: - build: 'avx' defines: '-DLLAMA_AVX2=OFF' - build: 'avx512' - defines: '-DLLAMA_AVX512=ON' + defines: '-DLLAMA_AVX512=ON -DLLAMA_AVX512_VBMI=ON -DLLAMA_AVX512_VNNI=ON' runs-on: windows-latest steps: - uses: actions/checkout@v3 diff --git a/LLama/Native/NativeLibraryConfig.cs b/LLama/Native/NativeLibraryConfig.cs index e51359707..aaa328c91 100644 --- a/LLama/Native/NativeLibraryConfig.cs +++ b/LLama/Native/NativeLibraryConfig.cs @@ -194,10 +194,31 @@ private NativeLibraryConfig() _avxLevel = AvxLevel.Avx; if (System.Runtime.Intrinsics.X86.Avx2.IsSupported) _avxLevel = AvxLevel.Avx2; -#if NET8_0_OR_GREATER - if (System.Runtime.Intrinsics.X86.Avx512F.IsSupported) + + if (CheckAVX512()) _avxLevel = AvxLevel.Avx512; + } + + private static bool CheckAVX512() + { + if (!System.Runtime.Intrinsics.X86.X86Base.IsSupported) + return false; + + var (_, ebx, ecx, _) = System.Runtime.Intrinsics.X86.X86Base.CpuId(7, 0); + + var vnni = (ecx & 0b_1000_0000_0000) != 0; + +#if NET8_0_OR_GREATER + var f = System.Runtime.Intrinsics.X86.Avx512F.IsSupported; + var bw = System.Runtime.Intrinsics.X86.Avx512BW.IsSupported; + var vbmi = System.Runtime.Intrinsics.X86.Avx512Vbmi.IsSupported; +#else + var f = (ebx & (1 << 16)) != 0; + var bw = (ebx & (1 << 30)) != 0; + var vbmi = (ecx & 0b_0000_0000_0010) != 
0; #endif + + return vnni && vbmi && bw && f; } /// @@ -253,4 +274,4 @@ public override string ToString() } } #endif -} + }