diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index 6f36358d2..9042a60b5 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -1,5 +1,6 @@
using System.Diagnostics;
using System.Text;
+using LLama.Abstractions;
using LLama.Common;
using LLama.Native;
@@ -30,6 +31,7 @@ public static async Task Run()
// Load model
var parameters = new ModelParams(modelPath);
+
using var model = LLamaWeights.LoadFromFile(parameters);
// Tokenize prompt
diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj
index d3acbb822..eb70e2189 100644
--- a/LLama.Examples/LLama.Examples.csproj
+++ b/LLama.Examples/LLama.Examples.csproj
@@ -2,7 +2,7 @@
Exe
- net6.0;net7.0;net8.0
+ net6.0;net8.0
enable
enable
AnyCPU;x64
diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs
index 9feb62023..a8b9b39a3 100644
--- a/LLama.Examples/Program.cs
+++ b/LLama.Examples/Program.cs
@@ -7,7 +7,10 @@
Console.WriteLine("======================================================================================================");
-NativeLibraryConfig.Instance.WithCuda().WithLogs();
+NativeLibraryConfig
+ .Instance
+ .WithCuda()
+ .WithLogs();
NativeApi.llama_empty_call();
Console.WriteLine();
diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs
index 34c9c21bb..e52adc41b 100644
--- a/LLama.Unittest/ModelsParamsTests.cs
+++ b/LLama.Unittest/ModelsParamsTests.cs
@@ -1,5 +1,6 @@
using LLama.Common;
using System.Text.Json;
+using LLama.Abstractions;
namespace LLama.Unittest
{
@@ -8,23 +9,39 @@ public class ModelsParamsTests
[Fact]
public void SerializeRoundTripSystemTextJson()
{
+ var options = new JsonSerializerOptions()
+ {
+ WriteIndented = true,
+ };
+
var expected = new ModelParams("abc/123")
{
BatchSize = 17,
ContextSize = 42,
Seed = 42,
GpuLayerCount = 111,
- TensorSplits = { [0] = 3 }
+ TensorSplits = { [0] = 3 },
+ MetadataOverrides =
+ {
+ new MetadataOverride("hello", true),
+ new MetadataOverride("world", 17),
+ new MetadataOverride("cats", 17f),
+ }
};
- var json = JsonSerializer.Serialize(expected);
- var actual = JsonSerializer.Deserialize(json)!;
+ var json = JsonSerializer.Serialize(expected, options);
+ var actual = JsonSerializer.Deserialize(json, options)!;
// Cannot compare splits with default equality, check they are sequence equal and then set to null
- Assert.Equal((IEnumerable)expected.TensorSplits, expected.TensorSplits);
+ Assert.True(expected.TensorSplits.SequenceEqual(actual.TensorSplits));
actual.TensorSplits = null!;
expected.TensorSplits = null!;
+ // Cannot compare overrides with default equality, check they are sequence equal and then set to null
+ Assert.True(expected.MetadataOverrides.SequenceEqual(actual.MetadataOverrides));
+ actual.MetadataOverrides = null!;
+ expected.MetadataOverrides = null!;
+
// Check encoding is the same
var b1 = expected.Encoding.GetBytes("Hello");
var b2 = actual.Encoding.GetBytes("Hello");
@@ -32,35 +49,5 @@ public void SerializeRoundTripSystemTextJson()
Assert.Equal(expected, actual);
}
-
- //[Fact]
- //public void SerializeRoundTripNewtonsoft()
- //{
- // var expected = new ModelParams("abc/123")
- // {
- // BatchSize = 17,
- // ContextSize = 42,
- // Seed = 42,
- // GpuLayerCount = 111,
- // LoraAdapters =
- // {
- // new("abc", 1),
- // new("def", 0)
- // },
- // TensorSplits = { [0] = 3 }
- // };
-
- // var settings = new Newtonsoft.Json.JsonSerializerSettings();
-
- // var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected, settings);
- // var actual = Newtonsoft.Json.JsonConvert.DeserializeObject(json, settings)!;
-
- // // Cannot compare splits with default equality, check they are sequence equal and then set to null
- // Assert.Equal((IEnumerable)expected.TensorSplits, expected.TensorSplits);
- // actual.TensorSplits = null!;
- // expected.TensorSplits = null!;
-
- // Assert.Equal(expected, actual);
- //}
}
}
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index da840fe9d..7b770b389 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -17,106 +17,55 @@ public class ModelOptions
///
public int MaxInstances { get; set; }
- ///
- /// Model context size (n_ctx)
- ///
+ ///
public uint? ContextSize { get; set; }
- ///
- /// the GPU that is used for scratch and small tensors
- ///
+ ///
public int MainGpu { get; set; } = 0;
- ///
- /// if true, reduce VRAM usage at the cost of performance
- ///
- public bool LowVram { get; set; } = false;
-
- ///
- /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
- ///
+ ///
public int GpuLayerCount { get; set; } = 20;
- ///
- /// Seed for the random number generator (seed)
- ///
+ ///
public uint Seed { get; set; } = 1686349486;
- ///
- /// Use f16 instead of f32 for memory kv (memory_f16)
- ///
- public bool UseFp16Memory { get; set; } = true;
-
- ///
- /// Use mmap for faster loads (use_mmap)
- ///
+ ///
public bool UseMemorymap { get; set; } = true;
- ///
- /// Use mlock to keep model in memory (use_mlock)
- ///
+ ///
public bool UseMemoryLock { get; set; } = false;
- ///
- /// Compute perplexity over the prompt (perplexity)
- ///
- public bool Perplexity { get; set; } = false;
-
- ///
- /// Model path (model)
- ///
+ ///
public string ModelPath { get; set; }
- ///
- /// List of LoRAs to apply
- ///
+ ///
public AdapterCollection LoraAdapters { get; set; } = new();
- ///
-
- /// base model path for the lora adapter (lora_base)
- ///
+ ///
public string LoraBase { get; set; } = string.Empty;
- ///
- /// Number of threads (null = autodetect) (n_threads)
- ///
+ ///
public uint? Threads { get; set; }
- ///
- /// Number of threads to use for batch processing (null = autodetect) (n_threads)
- ///
+ ///
public uint? BatchThreads { get; set; }
- ///
- /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
- ///
+ ///
public uint BatchSize { get; set; } = 512;
- ///
- /// Whether to convert eos to newline during the inference.
- ///
- public bool ConvertEosToNewLine { get; set; } = false;
-
- ///
- /// Whether to use embedding mode. (embedding) Note that if this is set to true,
- /// The LLamaModel won't produce text response anymore.
- ///
+ ///
public bool EmbeddingMode { get; set; } = false;
- ///
- /// how split tensors should be distributed across GPUs
- ///
+ ///
public TensorSplitsCollection TensorSplits { get; set; } = new();
- ///
- /// RoPE base frequency
- ///
+ ///
+ public List MetadataOverrides { get; } = new();
+
+ ///
public float? RopeFrequencyBase { get; set; }
- ///
- /// RoPE frequency scaling factor
- ///
+ ///
public float? RopeFrequencyScale { get; set; }
///
@@ -137,19 +86,19 @@ public class ModelOptions
///
public RopeScalingType? YarnScalingType { get; set; }
- ///
- /// Use experimental mul_mat_q kernels
- ///
- public bool MulMatQ { get; set; }
+ ///
+ public GGMLType? TypeK { get; set; }
- ///
- /// The encoding to use for models
- ///
+ ///
+ public GGMLType? TypeV { get; set; }
+
+ ///
+ public bool NoKqvOffload { get; set; }
+
+ ///
public Encoding Encoding { get; set; } = Encoding.UTF8;
- ///
- /// Load vocab only (no weights)
- ///
+ ///
public bool VocabOnly { get; set; }
}
}
\ No newline at end of file
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
index da7498524..d09a6a7c6 100644
--- a/LLama/Abstractions/IContextParams.cs
+++ b/LLama/Abstractions/IContextParams.cs
@@ -23,16 +23,6 @@ public interface IContextParams
///
uint Seed { get; set; }
- ///
- /// Use f16 instead of f32 for memory kv (memory_f16)
- ///
- bool UseFp16Memory { get; set; }
-
- ///
- /// Compute perplexity over the prompt (perplexity)
- ///
- bool Perplexity { get; set; }
-
///
/// Whether to use embedding mode. (embedding) Note that if this is set to true,
/// The LLamaModel won't produce text response anymore.
@@ -49,11 +39,6 @@ public interface IContextParams
///
float? RopeFrequencyScale { get; set; }
- ///
- /// Use experimental mul_mat_q kernels
- ///
- bool MulMatQ { get; set; }
-
///
/// The encoding to use for models
///
@@ -70,27 +55,27 @@ public interface IContextParams
uint? BatchThreads { get; set; }
///
- /// YaRN extrapolation mix factor
+ /// YaRN extrapolation mix factor (null = from model)
///
float? YarnExtrapolationFactor { get; set; }
///
- /// YaRN magnitude scaling factor
+ /// YaRN magnitude scaling factor (null = from model)
///
float? YarnAttentionFactor { get; set; }
///
- /// YaRN low correction dim
+ /// YaRN low correction dim (null = from model)
///
float? YarnBetaFast { get; set; }
///
- /// YaRN high correction dim
+ /// YaRN high correction dim (null = from model)
///
float? YarnBetaSlow { get; set; }
///
- /// YaRN original context length
+ /// YaRN original context length (null = from model)
///
uint? YarnOriginalContext { get; set; }
@@ -98,4 +83,19 @@ public interface IContextParams
/// YaRN scaling method to use.
///
RopeScalingType? YarnScalingType { get; set; }
+
+ ///
+ /// Override the type of the K cache
+ ///
+ GGMLType? TypeK { get; set; }
+
+ ///
+ /// Override the type of the V cache
+ ///
+ GGMLType? TypeV { get; set; }
+
+ ///
+ /// Whether to disable offloading the KQV cache to the GPU
+ ///
+ bool NoKqvOffload { get; set; }
}
\ No newline at end of file
diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index 4a3dde7a2..3b1553d5c 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -5,7 +5,6 @@
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
-using LLama.Common;
using LLama.Native;
namespace LLama.Abstractions
@@ -59,6 +58,11 @@ public interface IModelParams
/// base model path for the lora adapter (lora_base)
///
string LoraBase { get; set; }
+
+ ///
+ /// Override specific metadata items in the model
+ ///
+ List MetadataOverrides { get; }
}
///
@@ -105,6 +109,7 @@ public override int GetHashCode()
}
}
+
///
/// A fixed size array to set the tensor splits across multiple GPUs
///
@@ -186,7 +191,7 @@ public class TensorSplitsCollectionConverter
: JsonConverter
{
///
- public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var arr = JsonSerializer.Deserialize(ref reader, options) ?? Array.Empty();
return new TensorSplitsCollection(arr);
@@ -198,4 +203,131 @@ public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value,
JsonSerializer.Serialize(writer, value.Splits, options);
}
}
+
+
+ ///
+ /// An override for a single key/value pair in model metadata
+ ///
+ [JsonConverter(typeof(MetadataOverrideConverter))]
+ public sealed record MetadataOverride
+ {
+ ///
+ /// Get the key being overriden by this override
+ ///
+ public string Key { get; init; }
+
+ internal LLamaModelKvOverrideType Type { get; }
+
+ private readonly int _valueInt;
+ private readonly float _valueFloat;
+ private readonly bool _valueBool;
+
+ ///
+ /// Create a new override for an int key
+ ///
+ ///
+ ///
+ public MetadataOverride(string key, int value)
+ {
+ Key = key;
+ _valueInt = value;
+ Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
+ }
+
+ ///
+ /// Create a new override for a float key
+ ///
+ ///
+ ///
+ public MetadataOverride(string key, float value)
+ {
+ Key = key;
+ _valueFloat = value;
+ Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
+ }
+
+ ///
+ /// Create a new override for a boolean key
+ ///
+ ///
+ ///
+ public MetadataOverride(string key, bool value)
+ {
+ Key = key;
+ _valueBool = value;
+ Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
+ }
+
+ internal void WriteValue(ref LLamaModelMetadataOverride dest)
+ {
+ switch (Type)
+ {
+ case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
+ dest.IntValue = _valueInt;
+ break;
+ case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
+ dest.FloatValue = _valueFloat;
+ break;
+ case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
+ dest.BoolValue = _valueBool ? -1 : 0;
+ break;
+ default:
+ throw new ArgumentOutOfRangeException();
+ }
+ }
+
+ internal void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options)
+ {
+ switch (Type)
+ {
+ case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
+ writer.WriteNumberValue(_valueInt);
+ break;
+ case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
+ writer.WriteNumberValue(_valueFloat);
+ break;
+ case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
+ writer.WriteBooleanValue(_valueBool);
+ break;
+ default:
+ throw new ArgumentOutOfRangeException();
+ }
+ }
+ }
+
+ ///
+ /// A JSON converter for
+ ///
+ public class MetadataOverrideConverter
+ : JsonConverter
+ {
+ ///
+ public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ var ktv = JsonSerializer.Deserialize(ref reader, options)!;
+
+ return ((LLamaModelKvOverrideType)ktv.Type) switch
+ {
+ LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
+ LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
+ LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
+ _ => throw new JsonException(),
+ };
+ }
+
+ ///
+ public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
+ {
+ writer.WriteStartObject();
+ {
+ writer.WriteNumber("Type", (int)value.Type);
+ writer.WriteString("Key", value.Key);
+ writer.WritePropertyName("Value");
+ value.WriteValue(writer, options);
+ }
+ writer.WriteEndObject();
+ }
+
+ private record KeyTypeValue(int Type, string Key, JsonElement Value);
+ }
}
\ No newline at end of file
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index cecd655a1..b124b84db 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -1,9 +1,8 @@
using LLama.Abstractions;
-using System;
using System.Text;
-using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Native;
+using System.Collections.Generic;
namespace LLama.Common
{
@@ -25,18 +24,12 @@ public record ModelParams
///
public uint Seed { get; set; } = 0xFFFFFFFF;
- ///
- public bool UseFp16Memory { get; set; } = true;
-
///
public bool UseMemorymap { get; set; } = true;
///
public bool UseMemoryLock { get; set; }
- ///
- public bool Perplexity { get; set; }
-
///
public string ModelPath { get; set; }
@@ -61,6 +54,9 @@ public record ModelParams
///
public TensorSplitsCollection TensorSplits { get; set; } = new();
+ ///
+ public List MetadataOverrides { get; set; } = new();
+
///
public float? RopeFrequencyBase { get; set; }
@@ -86,7 +82,13 @@ public record ModelParams
public RopeScalingType? YarnScalingType { get; set; }
///
- public bool MulMatQ { get; set; }
+ public GGMLType? TypeK { get; set; }
+
+ ///
+ public GGMLType? TypeV { get; set; }
+
+ ///
+ public bool NoKqvOffload { get; set; }
///
public bool VocabOnly { get; set; }
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index bb029c162..212736170 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -24,8 +24,6 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.n_ctx = @params.ContextSize ?? 0;
result.n_batch = @params.BatchSize;
result.seed = @params.Seed;
- result.f16_kv = @params.UseFp16Memory;
- result.logits_all = @params.Perplexity;
result.embedding = @params.EmbeddingMode;
result.rope_freq_base = @params.RopeFrequencyBase ?? 0;
result.rope_freq_scale = @params.RopeFrequencyScale ?? 0;
@@ -38,7 +36,9 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED;
- result.mul_mat_q = @params.MulMatQ;
+ result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
+ result.type_k = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
+ result.offload_kqv = !@params.NoKqvOffload;
result.n_threads = Threads(@params.Threads);
result.n_threads_batch = Threads(@params.BatchThreads);
diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index a9c2d10ef..08805d320 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -1,41 +1,82 @@
using System.IO;
using System;
-using System.Buffers;
+using System.Text;
using LLama.Abstractions;
using LLama.Native;
-namespace LLama.Extensions
+namespace LLama.Extensions;
+
+///
+/// Extention methods to the IModelParams interface
+///
+public static class IModelParamsExtensions
{
///
- /// Extention methods to the IModelParams interface
+ /// Convert the given `IModelParams` into a `LLamaModelParams`
///
- public static class IModelParamsExtensions
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLamaModelParams result)
{
- ///
- /// Convert the given `IModelParams` into a `LLamaModelParams`
- ///
- ///
- ///
- ///
- ///
- ///
- public static MemoryHandle ToLlamaModelParams(this IModelParams @params, out LLamaModelParams result)
- {
- result = NativeApi.llama_model_default_params();
+ var disposer = new GroupDisposable();
+
+ result = NativeApi.llama_model_default_params();
- result.main_gpu = @params.MainGpu;
- result.n_gpu_layers = @params.GpuLayerCount;
- result.use_mlock = @params.UseMemoryLock;
- result.use_mmap = @params.UseMemorymap;
- result.vocab_only = @params.VocabOnly;
+ result.main_gpu = @params.MainGpu;
+ result.n_gpu_layers = @params.GpuLayerCount;
+ result.use_mlock = @params.UseMemoryLock;
+ result.use_mmap = @params.UseMemorymap;
+ result.vocab_only = @params.VocabOnly;
- var pin = @params.TensorSplits.Pin();
+ unsafe
+ {
+ result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
+ }
+
+ if (@params.MetadataOverrides.Count == 0)
+ {
+ unsafe
+ {
+ result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
+ }
+ }
+ else
+ {
+ // Allocate enough space for all the override items
+ var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1];
+ var overridesPin = overrides.AsMemory().Pin();
unsafe
{
- result.tensor_split = (float*)pin.Pointer;
+ result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer;
}
- return pin;
+ // Convert each item
+ for (var i = 0; i < @params.MetadataOverrides.Count; i++)
+ {
+ var item = @params.MetadataOverrides[i];
+ var native = new LLamaModelMetadataOverride
+ {
+ Tag = item.Type
+ };
+
+ item.WriteValue(ref native);
+
+ // Convert key to bytes
+ unsafe
+ {
+ fixed (char* srcKey = item.Key)
+ {
+ Encoding.UTF8.GetBytes(srcKey, 0, native.key, 128);
+ }
+ }
+
+ overrides[i] = native;
+ }
}
+
+ return disposer;
}
-}
+}
\ No newline at end of file
diff --git a/LLama/Native/GGMLType.cs b/LLama/Native/GGMLType.cs
new file mode 100644
index 000000000..8bd33211d
--- /dev/null
+++ b/LLama/Native/GGMLType.cs
@@ -0,0 +1,107 @@
+namespace LLama.Native;
+
+///
+/// Possible GGML quantisation types
+///
+public enum GGMLType
+{
+ ///
+ /// Full 32 bit float
+ ///
+ GGML_TYPE_F32 = 0,
+
+ ///
+ /// 16 bit float
+ ///
+ GGML_TYPE_F16 = 1,
+
+ ///
+ /// 4 bit float
+ ///
+ GGML_TYPE_Q4_0 = 2,
+
+ ///
+ /// 4 bit float
+ ///
+ GGML_TYPE_Q4_1 = 3,
+
+ // GGML_TYPE_Q4_2 = 4, support has been removed
+ // GGML_TYPE_Q4_3 (5) support has been removed
+
+ ///
+ /// 5 bit float
+ ///
+ GGML_TYPE_Q5_0 = 6,
+
+ ///
+ /// 5 bit float
+ ///
+ GGML_TYPE_Q5_1 = 7,
+
+ ///
+ /// 8 bit float
+ ///
+ GGML_TYPE_Q8_0 = 8,
+
+ ///
+ /// 8 bit float
+ ///
+ GGML_TYPE_Q8_1 = 9,
+
+ // k-quantizations
+
+ ///
+ /// "type-1" 2-bit quantization in super-blocks containing 16 blocks, each block having 16 weight.
+ /// Block scales and mins are quantized with 4 bits. This ends up effectively using 2.5625 bits per weight (bpw)
+ ///
+ GGML_TYPE_Q2_K = 10,
+
+ ///
+ /// "type-0" 3-bit quantization in super-blocks containing 16 blocks, each block having 16 weights.
+ /// Scales are quantized with 6 bits. This end up using 3.4375 bpw.
+ ///
+ GGML_TYPE_Q3_K = 11,
+
+ ///
+ /// "type-1" 4-bit quantization in super-blocks containing 8 blocks, each block having 32 weights.
+ /// Scales and mins are quantized with 6 bits. This ends up using 4.5 bpw.
+ ///
+ GGML_TYPE_Q4_K = 12,
+
+ ///
+ /// "type-1" 5-bit quantization. Same super-block structure as GGML_TYPE_Q4_K resulting in 5.5 bpw
+ ///
+ GGML_TYPE_Q5_K = 13,
+
+ ///
+ /// "type-0" 6-bit quantization. Super-blocks with 16 blocks, each block having 16 weights.
+ /// Scales are quantized with 8 bits. This ends up using 6.5625 bpw
+ ///
+ GGML_TYPE_Q6_K = 14,
+
+ ///
+ /// "type-0" 8-bit quantization. Only used for quantizing intermediate results.
+ /// The difference to the existing Q8_0 is that the block size is 256. All 2-6 bit dot products are implemented for this quantization type.
+ ///
+ GGML_TYPE_Q8_K = 15,
+
+ ///
+ /// Integer, 8 bit
+ ///
+ GGML_TYPE_I8 = 16,
+
+ ///
+ /// Integer, 16 bit
+ ///
+ GGML_TYPE_I16 = 17,
+
+ ///
+ /// Integer, 32 bit
+ ///
+ GGML_TYPE_I32 = 18,
+
+ ///
+ /// The value of this entry is the count of the number of possible quant types.
+ ///
+ GGML_TYPE_COUNT,
+}
\ No newline at end of file
diff --git a/LLama/Native/GroupDisposable.cs b/LLama/Native/GroupDisposable.cs
new file mode 100644
index 000000000..5238c1ade
--- /dev/null
+++ b/LLama/Native/GroupDisposable.cs
@@ -0,0 +1,57 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+
+namespace LLama.Native;
+
+///
+/// Disposes all contained disposables when this class is disposed
+///
+internal sealed class GroupDisposable
+ : IDisposable
+{
+ private bool _disposed;
+
+ private readonly List _handles = new();
+ private readonly List _disposables = new();
+
+ ///
+ ~GroupDisposable()
+ {
+ Dispose();
+ }
+
+ public MemoryHandle Add(MemoryHandle handle)
+ {
+ if (_disposed)
+ throw new ObjectDisposedException("Cannot add new handle, already disposed");
+ _handles.Add(handle);
+
+ return handle;
+ }
+
+ public T Add(T disposable)
+ where T : class, IDisposable
+ {
+ if (_disposed)
+ throw new ObjectDisposedException("Cannot add new IDisposable, already disposed");
+ _disposables.Add(disposable);
+
+ return disposable;
+ }
+
+ ///
+ public void Dispose()
+ {
+ if (_disposed)
+ return;
+
+ foreach (var memoryHandle in _handles)
+ memoryHandle.Dispose();
+ foreach (var disposable in _disposables)
+ disposable.Dispose();
+
+ _disposed = true;
+ GC.SuppressFinalize(this);
+ }
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index c0f2afa29..bfd39ea46 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -56,7 +56,7 @@ public struct LLamaContextParams
///
public float rope_freq_scale;
///
- /// YaRN extrapolation mix factor, NaN = from model
+ /// YaRN extrapolation mix factor, negative = from model
///
public float yarn_ext_factor;
///
@@ -75,36 +75,26 @@ public struct LLamaContextParams
///
/// YaRN original context size
///
- public uint yarn_orig_ctx;
-
+ public uint yarn_orig_ctx;
+
///
- /// if true, use experimental mul_mat_q kernels
+ /// data type for K cache
///
- public bool mul_mat_q
- {
- readonly get => Convert.ToBoolean(_mul_mat_q);
- set => _mul_mat_q = Convert.ToSByte(value);
- }
- private sbyte _mul_mat_q;
+ public GGMLType type_k;
///
- /// use fp16 for KV cache
+ /// data type for V cache
///
- public bool f16_kv
- {
- readonly get => Convert.ToBoolean(_f16_kv);
- set => _f16_kv = Convert.ToSByte(value);
- }
- private sbyte _f16_kv;
+ public GGMLType type_v;
///
- /// the llama_eval() call computes all logits, not just the last one
+ /// Deprecated!
+ ///
+ private sbyte _mul_mat_q;
+
+ ///
+ /// Deprecated!
///
- public bool logits_all
- {
- readonly get => Convert.ToBoolean(_logits_all);
- set => _logits_all = Convert.ToSByte(value);
- }
private sbyte _logits_all;
///
@@ -116,6 +106,16 @@ public bool embedding
set => _embedding = Convert.ToSByte(value);
}
private sbyte _embedding;
+
+ ///
+ /// whether to offload the KQV ops (including the KV cache) to GPU
+ ///
+ public bool offload_kqv
+ {
+ readonly get => Convert.ToBoolean(_offload_kqv);
+ set => _offload_kqv = Convert.ToSByte(value);
+ }
+ private sbyte _offload_kqv;
}
}
diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs
new file mode 100644
index 000000000..ea1c1172c
--- /dev/null
+++ b/LLama/Native/LLamaKvCacheView.cs
@@ -0,0 +1,174 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native;
+
+///
+/// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
+///
+[StructLayout(LayoutKind.Sequential)]
+public struct LLamaKvCacheViewCell
+{
+ ///
+ /// The position for this cell. Takes KV cache shifts into account.
+ /// May be negative if the cell is not populated.
+ ///
+ public LLamaPos pos;
+};
+
+///
+/// An updateable view of the KV cache (llama_kv_cache_view)
+///
+[StructLayout(LayoutKind.Sequential)]
+public unsafe struct LLamaKvCacheView
+{
+ // Number of KV cache cells. This will be the same as the context size.
+ int n_cells;
+
+ // Maximum number of sequences that can exist in a cell. It's not an error
+ // if there are more sequences in a cell than this value, however they will
+ // not be visible in the view cells_sequences.
+ int n_max_seq;
+
+ // Number of tokens in the cache. For example, if there are two populated
+ // cells, the first with 1 sequence id in it and the second with 2 sequence
+ // ids then you'll have 3 tokens.
+ int token_count;
+
+ // Number of populated cache cells.
+ int used_cells;
+
+ // Maximum contiguous empty slots in the cache.
+ int max_contiguous;
+
+ // Index to the start of the max_contiguous slot range. Can be negative
+ // when cache is full.
+ int max_contiguous_idx;
+
+ // Information for an individual cell.
+ LLamaKvCacheViewCell* cells;
+
+ // The sequences for each cell. There will be n_max_seq items per cell.
+ LLamaSeqId* cells_sequences;
+}
+
+///
+/// A safe handle for a LLamaKvCacheView
+///
+public class LLamaKvCacheViewSafeHandle
+ : SafeLLamaHandleBase
+{
+ private readonly SafeLLamaContextHandle _ctx;
+ private LLamaKvCacheView _view;
+
+ ///
+ /// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed
+ ///
+ ///
+ ///
+ public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view)
+ : base((IntPtr)1, true)
+ {
+ _ctx = ctx;
+ _view = view;
+ }
+
+ ///
+ /// Allocate a new llama_kv_cache_view_free
+ ///
+ ///
+ /// The maximum number of sequences visible in this view per cell
+ ///
+ public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
+ {
+ var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences);
+ return new LLamaKvCacheViewSafeHandle(ctx, result);
+ }
+
+ ///
+ protected override bool ReleaseHandle()
+ {
+ NativeApi.llama_kv_cache_view_free(ref _view);
+ SetHandle(IntPtr.Zero);
+
+ return true;
+ }
+
+ ///
+ /// Update this view
+ ///
+ public void Update()
+ {
+ NativeApi.llama_kv_cache_view_update(_ctx, ref _view);
+ }
+
+ ///
+ /// Count the number of used cells in the KV cache
+ ///
+ ///
+ public int CountCells()
+ {
+ return NativeApi.llama_get_kv_cache_used_cells(_ctx);
+ }
+
+ ///
+ /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be countered multiple times
+ ///
+ ///
+ public int CountTokens()
+ {
+ return NativeApi.llama_get_kv_cache_token_count(_ctx);
+ }
+
+ ///
+ /// Get the raw KV cache view
+ ///
+ ///
+ public ref LLamaKvCacheView GetView()
+ {
+ return ref _view;
+ }
+}
+
+partial class NativeApi
+{
+ ///
+ /// Create an empty KV cache view. (use only for debugging purposes)
+ ///
+ ///
+ ///
+ ///
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern LLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_max_seq);
+
+ ///
+ /// Free a KV cache view. (use only for debugging purposes)
+ ///
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view);
+
+ ///
+ /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
+ ///
+ ///
+ ///
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view);
+
+ ///
+ /// Returns the number of tokens in the KV cache (slow, use only for debug)
+ /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
+ ///
+ ///
+ ///
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx);
+
+ ///
+ /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
+ ///
+ ///
+ ///
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx);
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaModelMetadataOverride.cs b/LLama/Native/LLamaModelMetadataOverride.cs
new file mode 100644
index 000000000..7f685edfc
--- /dev/null
+++ b/LLama/Native/LLamaModelMetadataOverride.cs
@@ -0,0 +1,61 @@
+using System.Runtime.InteropServices;
+
+namespace LLama.Native;
+
+///
+/// Override a key/value pair in the llama model metadata (llama_model_kv_override)
+///
+[StructLayout(LayoutKind.Explicit)]
+public unsafe struct LLamaModelMetadataOverride
+{
+ ///
+ /// Key to override
+ ///
+ [FieldOffset(0)]
+ public fixed byte key[128];
+
+ ///
+ /// Type of value
+ ///
+ [FieldOffset(128)]
+ public LLamaModelKvOverrideType Tag;
+
+ ///
+ /// Value, **must** only be used if Tag == LLAMA_KV_OVERRIDE_INT
+ ///
+ [FieldOffset(132)]
+ public long IntValue;
+
+ ///
+ /// Value, **must** only be used if Tag == LLAMA_KV_OVERRIDE_FLOAT
+ ///
+ [FieldOffset(132)]
+ public double FloatValue;
+
+ ///
+ /// Value, **must** only be used if Tag == LLAMA_KV_OVERRIDE_BOOL
+ ///
+ [FieldOffset(132)]
+ public int BoolValue;
+}
+
+///
+/// Specifies what type of value is being overridden by LLamaModelKvOverride
+///
+public enum LLamaModelKvOverrideType
+{
+ ///
+ /// Overriding an int value
+ ///
+ LLAMA_KV_OVERRIDE_INT = 0,
+
+ ///
+ /// Overriding a float value
+ ///
+ LLAMA_KV_OVERRIDE_FLOAT = 1,
+
+ ///
+ /// Overriding a bool value
+ ///
+ LLAMA_KV_OVERRIDE_BOOL = 2,
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index a0d307750..ed7b60439 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -34,6 +34,11 @@ public unsafe struct LLamaModelParams
///
public void* progress_callback_user_data;
+ ///
+ /// override key-value pairs of the model meta data
+ ///
+ public LLamaModelMetadataOverride* kv_overrides;
+
///
/// only load the vocabulary, no weights
///
diff --git a/LLama/runtimes/deps/avx/libllama.dll b/LLama/runtimes/deps/avx/libllama.dll
index 55d574843..954bb194e 100644
Binary files a/LLama/runtimes/deps/avx/libllama.dll and b/LLama/runtimes/deps/avx/libllama.dll differ
diff --git a/LLama/runtimes/deps/avx/libllama.so b/LLama/runtimes/deps/avx/libllama.so
index e9360b95b..4b788e624 100644
Binary files a/LLama/runtimes/deps/avx/libllama.so and b/LLama/runtimes/deps/avx/libllama.so differ
diff --git a/LLama/runtimes/deps/avx2/libllama.dll b/LLama/runtimes/deps/avx2/libllama.dll
index 52330a971..8a0e86c7d 100644
Binary files a/LLama/runtimes/deps/avx2/libllama.dll and b/LLama/runtimes/deps/avx2/libllama.dll differ
diff --git a/LLama/runtimes/deps/avx2/libllama.so b/LLama/runtimes/deps/avx2/libllama.so
index 9f84c424c..c299ee65f 100644
Binary files a/LLama/runtimes/deps/avx2/libllama.so and b/LLama/runtimes/deps/avx2/libllama.so differ
diff --git a/LLama/runtimes/deps/avx512/libllama.dll b/LLama/runtimes/deps/avx512/libllama.dll
index 5f68f81b4..709faf9a1 100644
Binary files a/LLama/runtimes/deps/avx512/libllama.dll and b/LLama/runtimes/deps/avx512/libllama.dll differ
diff --git a/LLama/runtimes/deps/avx512/libllama.so b/LLama/runtimes/deps/avx512/libllama.so
index 2791a7491..e9290e661 100644
Binary files a/LLama/runtimes/deps/avx512/libllama.so and b/LLama/runtimes/deps/avx512/libllama.so differ
diff --git a/LLama/runtimes/deps/cu11.7.1/libllama.dll b/LLama/runtimes/deps/cu11.7.1/libllama.dll
index 8aa06f952..4440d33e2 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/libllama.dll and b/LLama/runtimes/deps/cu11.7.1/libllama.dll differ
diff --git a/LLama/runtimes/deps/cu11.7.1/libllama.so b/LLama/runtimes/deps/cu11.7.1/libllama.so
index 4f98e823b..9bce0d517 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/libllama.so and b/LLama/runtimes/deps/cu11.7.1/libllama.so differ
diff --git a/LLama/runtimes/deps/cu12.1.0/libllama.dll b/LLama/runtimes/deps/cu12.1.0/libllama.dll
index 802e357e8..cab4b10b9 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/libllama.dll and b/LLama/runtimes/deps/cu12.1.0/libllama.dll differ
diff --git a/LLama/runtimes/deps/cu12.1.0/libllama.so b/LLama/runtimes/deps/cu12.1.0/libllama.so
index 5a794f8e7..8b579ed25 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/libllama.so and b/LLama/runtimes/deps/cu12.1.0/libllama.so differ
diff --git a/LLama/runtimes/deps/libllama.dll b/LLama/runtimes/deps/libllama.dll
index a68c94185..2aa3afdc2 100644
Binary files a/LLama/runtimes/deps/libllama.dll and b/LLama/runtimes/deps/libllama.dll differ
diff --git a/LLama/runtimes/deps/libllama.so b/LLama/runtimes/deps/libllama.so
index d0ef8a591..670555d1d 100644
Binary files a/LLama/runtimes/deps/libllama.so and b/LLama/runtimes/deps/libllama.so differ
diff --git a/LLama/runtimes/deps/osx-arm64/ggml-metal.metal b/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
index 5d1357cd7..773fac124 100644
--- a/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
+++ b/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
@@ -3,6 +3,8 @@
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
#define QK4_0 32
#define QR4_0 2
@@ -39,8 +41,15 @@ typedef struct {
int8_t qs[QK8_0]; // quants
} block_q8_0;
-// general-purpose kernel for addition of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+enum ggml_sort_order {
+ GGML_SORT_ASC,
+ GGML_SORT_DESC,
+};
+
+// general-purpose kernel for addition, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
// cons: not very efficient
kernel void kernel_add(
device const char * src0,
@@ -81,16 +90,111 @@ kernel void kernel_add(
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+ }
+}
+
+kernel void kernel_mul(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+ }
+}
+
+kernel void kernel_div(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
- src0_ptr += ntg.x*nb00;
- src1_ptr += ntg.x*nb10;
- dst_ptr += ntg.x*nb0;
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
}
}
@@ -105,23 +209,22 @@ kernel void kernel_add_row(
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
-kernel void kernel_mul(
+kernel void kernel_mul_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
+ constant int64_t & nb [[buffer(27)]],
uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src1[tpig];
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
}
-// assumption: src1 is a row
-// broadcast src1 into src0
-kernel void kernel_mul_row(
+kernel void kernel_div_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant int64_t & nb,
+ constant int64_t & nb [[buffer(27)]],
uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src1[tpig % nb];
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
}
kernel void kernel_scale(
@@ -162,6 +265,54 @@ kernel void kernel_sqr(
dst[tpig] = src0[tpig] * src0[tpig];
}
+kernel void kernel_sum_rows(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tpig[[thread_position_in_grid]]) {
+ int64_t i3 = tpig.z;
+ int64_t i2 = tpig.y;
+ int64_t i1 = tpig.x;
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
+ }
+
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+ float row_sum = 0;
+
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
+ row_sum += src_row[i0];
+ }
+
+ dst_row[0] = row_sum;
+}
+
constant float GELU_COEF_A = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -180,10 +331,12 @@ kernel void kernel_gelu(
kernel void kernel_soft_max(
device const float * src0,
+ device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
+ constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@@ -194,73 +347,82 @@ kernel void kernel_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
- float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+ float lmax = -INFINITY;
- for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
- lmax = MAX(lmax, psrc0[i00]);
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}
- float max = simd_max(lmax);
- if (tiisg == 0) {
- buf[sgitg] = max;
- }
+ // find the max value in the block
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
- }
- }
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- max = buf[0];
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- const float exp_psrc0 = exp(psrc0[i00] - max);
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
- // Remember the result of exp here. exp is expensive, so we really do not
- // wish to compute it twice.
pdst[i00] = exp_psrc0;
}
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
float sum = simd_sum(lsum);
- if (tiisg == 0) {
- buf[sgitg] = sum;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- buf[tpitg] += buf[tpitg + i];
- }
- }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
- sum = buf[0];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- pdst[i00] /= sum;
+ pdst[i00] *= inv_sum;
}
}
kernel void kernel_soft_max_4(
device const float * src0,
+ device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
+ constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@@ -271,64 +433,74 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
- float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+ float4 lmax4 = -INFINITY;
- for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
- lmax4 = fmax(lmax4, psrc4[i00]);
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
- float max = simd_max(lmax);
- if (tiisg == 0) {
- buf[sgitg] = max;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
- }
- }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- max = buf[0];
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- const float4 exp_psrc4 = exp(psrc4[i00] - max);
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
float sum = simd_sum(lsum);
- if (tiisg == 0) {
- buf[sgitg] = sum;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- buf[tpitg] += buf[tpitg + i];
- }
- }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
- sum = buf[0];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- pdst4[i00] /= sum;
+ pdst4[i00] *= inv_sum;
}
}
@@ -435,14 +607,13 @@ kernel void kernel_rms_norm(
constant int64_t & ne00,
constant uint64_t & nb01,
constant float & eps,
- threadgroup float * sum [[threadgroup(0)]],
+ threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
- device const float * x_scalar = (device const float *) x;
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
float4 sumf = 0;
float all_sum = 0;
@@ -453,40 +624,30 @@ kernel void kernel_rms_norm(
}
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
all_sum = simd_sum(all_sum);
- if (tiisg == 0) {
- sum[sgitg] = all_sum;
- }
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- // broadcast, simd group number is ntg / 32
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- }
- if (tpitg == 0) {
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {
- sum[0] += x_scalar[i];
+ if (tiisg == 0) {
+ buf[sgitg] = all_sum;
}
- sum[0] /= ne00;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ all_sum = buf[tiisg];
+ all_sum = simd_sum(all_sum);
+ }
- const float mean = sum[0];
+ const float mean = all_sum/ne00;
const float scale = 1.0f/sqrt(mean + eps);
device float4 * y = (device float4 *) (dst + tgpig*ne00);
- device float * y_scalar = (device float *) y;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
y[i00] = x[i00] * scale;
}
- if (tpitg == 0) {
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
- y_scalar[i00] = x_scalar[i00] * scale;
- }
- }
}
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -576,15 +737,25 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
// putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
//Note: This is a template, but strictly speaking it only applies to
// quantizations where the block size is 32. It also does not
// giard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template
-void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
- uint3 tgpig, uint tiisg, uint sgitg) {
+void mul_vec_q_n_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
@@ -593,7 +764,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
const int first_row = (r0 * nsg + sgitg) * nr;
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +817,14 @@ kernel void kernel_mul_mv_q4_0_f32(
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +836,14 @@ kernel void kernel_mul_mv_q4_1_f32(
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +855,14 @@ kernel void kernel_mul_mv_q5_0_f32(
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
@@ -697,33 +874,35 @@ kernel void kernel_mul_mv_q5_1_f32(
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
#define NB_Q8_0 8
-kernel void kernel_mul_mv_q8_0_f32(
+void kernel_mul_mv_q8_0_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nr = N_DST;
const int nsg = N_SIMDGROUP;
const int nw = N_SIMDWIDTH;
@@ -732,8 +911,14 @@ kernel void kernel_mul_mv_q8_0_f32(
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
+
const int first_row = (r0 * nsg + sgitg) * nr;
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -771,9 +956,29 @@ kernel void kernel_mul_mv_q8_0_f32(
}
}
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
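
The hunks above repeat one central change: the single `gqa` ratio is replaced by two broadcast ratios `r2` and `r3`. A minimal C sketch of the new offset arithmetic, assuming (as ggml does when binding these kernels) that the host passes `r2 = ne12/ne02` and `r3 = ne13/ne03`:

```c
#include <stdint.h>

/* src1 batch index `im` runs over ne12*ne13 matrices; src0 may have fewer
 * channels/batches (ne02, ne03) that are broadcast across src1. Every group
 * of r2 src1 channels (and r3 batches) shares one src0 channel (batch).
 * `nb` is the number of quant blocks per row, as in the kernels. */
static uint32_t src0_block_offset(uint32_t im, uint32_t nb,
                                  int64_t ne01, int64_t ne02, int64_t ne12,
                                  uint32_t r2, uint32_t r3) {
    const uint32_t i12 = im % ne12; /* src1 channel */
    const uint32_t i13 = im / ne12; /* src1 batch   */
    return (i12 / r2) * (nb * ne01) + (i13 / r3) * (nb * ne01 * ne02);
}
```

The old `im/gqa*(nb*ne0)` form handled only one level of grouping (and used `ne0` where the src0 row count `ne01` belongs); the two-ratio form broadcasts over both the channel and batch dimensions.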
#define N_F32_F32 4
-kernel void kernel_mul_mv_f32_f32(
+void kernel_mul_mv_f32_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
@@ -791,6 +996,8 @@ kernel void kernel_mul_mv_f32_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
@@ -798,7 +1005,12 @@ kernel void kernel_mul_mv_f32_f32(
const int64_t rb = tgpig.y*N_F32_F32;
const int64_t im = tgpig.z;
- device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const float * x = (device const float *) (src0 + offset0);
if (ne00 < 128) {
for (int row = 0; row < N_F32_F32; ++row) {
@@ -844,7 +1056,33 @@ kernel void kernel_mul_mv_f32_f32(
}
}
-#define N_F16_F16 4
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
+#define N_F16_F16 4
kernel void kernel_mul_mv_f16_f16(
device const char * src0,
@@ -864,6 +1102,8 @@ kernel void kernel_mul_mv_f16_f16(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
@@ -871,7 +1111,12 @@ kernel void kernel_mul_mv_f16_f16(
const int64_t rb = tgpig.y*N_F16_F16;
const int64_t im = tgpig.z;
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F16; ++row) {
@@ -917,7 +1162,7 @@ kernel void kernel_mul_mv_f16_f16(
}
}
-kernel void kernel_mul_mv_f16_f32_1row(
+void kernel_mul_mv_f16_f32_1row_impl(
device const char * src0,
device const char * src1,
device float * dst,
@@ -935,6 +1180,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
@@ -942,7 +1189,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
@@ -966,12 +1218,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
+}
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
#define N_F16_F32 4
-kernel void kernel_mul_mv_f16_f32(
+void kernel_mul_mv_f16_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
@@ -989,6 +1266,8 @@ kernel void kernel_mul_mv_f16_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
@@ -996,7 +1275,12 @@ kernel void kernel_mul_mv_f16_f32(
const int64_t rb = tgpig.y*N_F16_F32;
const int64_t im = tgpig.z;
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F32; ++row) {
@@ -1042,6 +1326,32 @@ kernel void kernel_mul_mv_f16_f32(
}
}
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
// Assumes row size (ne00) is a multiple of 4
kernel void kernel_mul_mv_f16_f32_l4(
device const char * src0,
@@ -1061,6 +1371,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
@@ -1068,7 +1380,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
const int64_t r0 = tgpig.x;
const int64_t im = tgpig.z;
- device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);
for (int r1 = 0; r1 < nrows; ++r1) {
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1120,17 +1437,21 @@ kernel void kernel_alibi_f32(
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+ const int64_t k = i3*ne3 + i2;
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
float m_k;
- if (i2 < n_heads_log2_floor) {
- m_k = pow(m0, i2 + 1);
+ if (k < n_heads_log2_floor) {
+ m_k = pow(m0, k + 1);
} else {
- m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+ m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
}
+
+ device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
+ device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+ const float src_v = *(device float *)(src_row + i00*nb00);
+ device float * dst_v = (device float *)(dst_row + i00*nb0);
+ *dst_v = i00 * m_k + src_v;
}
}
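
For context on the `m_k` computation above, a hedged host-side sketch of how the ALiBi slope bases are conventionally derived (ggml computes `m0`/`m1` on the CPU before dispatch; the function name here is illustrative):

```c
#include <math.h>

/* The first n_heads_log2_floor heads get slope m0^(k+1); the remaining
 * heads interpolate with m1 at odd exponents, matching the kernel's
 * k < n_heads_log2_floor branch. */
static void alibi_slope_bases(int n_heads, float max_bias,
                              int *n_heads_log2_floor, float *m0, float *m1) {
    *n_heads_log2_floor = 1 << (int) floorf(log2f((float) n_heads));
    *m0 = powf(2.0f, -max_bias          / *n_heads_log2_floor);
    *m1 = powf(2.0f, -(max_bias / 2.0f) / *n_heads_log2_floor);
}
```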
@@ -1335,9 +1656,61 @@ kernel void kernel_im2col_f16(
}
}
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ // bitonic sort
+ int col = tpitg[0];
+ int row = tgpig[1];
+
+ if (col >= ncols) return;
+
+ device const float * x_row = x + row * ncols;
+ device int32_t * dst_row = dst + row * ncols;
+
+ // initialize indices
+ if (col < ncols) {
+ dst_row[col] = col;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int k = 2; k <= ncols; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32;
+
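
A single-threaded C sketch of the same bitonic schedule may help when reading the kernel; like the Metal version, it only sorts correctly when `ncols` is a power of two:

```c
#include <stdint.h>

static void swap_i32(int32_t *a, int32_t *b) { int32_t t = *a; *a = *b; *b = t; }

/* Ascending argsort of one row; dst receives a permutation of 0..ncols-1. */
static void argsort_f32_i32_asc(const float *x, int32_t *dst, int ncols) {
    for (int i = 0; i < ncols; ++i) dst[i] = i;
    for (int k = 2; k <= ncols; k *= 2) {       /* bitonic sequence length  */
        for (int j = k / 2; j > 0; j /= 2) {    /* compare-exchange stride  */
            for (int col = 0; col < ncols; ++col) {
                const int ixj = col ^ j;
                if (ixj <= col) continue;       /* each pair handled once   */
                const int asc = (col & k) == 0;
                if (asc ? x[dst[col]] > x[dst[ixj]]
                        : x[dst[col]] < x[dst[ixj]]) {
                    swap_i32(&dst[col], &dst[ixj]);
                }
            }
        }
    }
}
```

In the kernel, each `col` is one thread, so the inner loop disappears and the `threadgroup_barrier` takes the place of the sequential sweep.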
kernel void kernel_cpy_f16_f16(
- device const half * src0,
- device half * dst,
+ device const half * src0,
+ device half * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -1376,6 +1749,47 @@ kernel void kernel_cpy_f16_f16(
}
}
+kernel void kernel_cpy_f16_f32(
+ device const half * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
+ }
+}
+
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,
@@ -1460,106 +1874,297 @@ kernel void kernel_cpy_f32_f32(
}
}
-kernel void kernel_concat(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
+kernel void kernel_cpy_f32_q8_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
+ device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- if (i02 < ne02) {
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
- src0_ptr += ntg.x*nb00;
- } else {
- ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
- src1_ptr += ntg.x*nb10;
- }
- dst_ptr += ntg.x*nb0;
- }
-}
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-//============================================ k-quants ======================================================
+ float amax = 0.0f; // absolute max
-#ifndef QK_K
-#define QK_K 256
-#else
-static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
-#endif
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = src[j];
+ amax = MAX(amax, fabs(v));
+ }
-#if QK_K == 256
-#define K_SCALE_SIZE 12
-#else
-#define K_SCALE_SIZE 4
-#endif
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
-typedef struct {
- uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
- uint8_t qs[QK_K/4]; // quants
- half d; // super-block scale for quantized scales
- half dmin; // super-block scale for quantized mins
-} block_q2_K;
-// 84 bytes / block
+ dst_data[i00/QK8_0].d = d;
-typedef struct {
- uint8_t hmask[QK_K/8]; // quants - high bit
- uint8_t qs[QK_K/4]; // quants - low 2 bits
-#if QK_K == 64
- uint8_t scales[2];
-#else
- uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
-#endif
- half d; // super-block scale
-} block_q3_K;
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = src[j]*id;
-#if QK_K == 64
-typedef struct {
- half d[2]; // super-block scales/mins
- uint8_t scales[2];
- uint8_t qs[QK_K/2]; // 4-bit quants
-} block_q4_K;
-#else
-typedef struct {
- half d; // super-block scale for quantized scales
- half dmin; // super-block scale for quantized mins
- uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+ dst_data[i00/QK8_0].qs[j] = round(x0);
+ }
+ }
+}
+
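
For reference, the Q8_0 scheme the new copy kernel implements, as a scalar C sketch (the struct mirrors ggml's `block_q8_0`, with the half-precision scale widened to `float` for brevity):

```c
#include <math.h>
#include <stdint.h>

#define QK8_0 32

typedef struct {
    float  d;          /* per-block scale (half in the real layout) */
    int8_t qs[QK8_0];  /* quantized values */
} block_q8_0_ref;

static void quantize_q8_0_ref(const float *src, block_q8_0_ref *dst) {
    float amax = 0.0f;                       /* absolute max of the block */
    for (int j = 0; j < QK8_0; ++j) {
        const float v = fabsf(src[j]);
        if (v > amax) amax = v;
    }
    const float d  = amax / 127.0f;          /* (1 << 7) - 1 */
    const float id = d ? 1.0f / d : 0.0f;
    dst->d = d;
    for (int j = 0; j < QK8_0; ++j) {
        dst->qs[j] = (int8_t) roundf(src[j] * id);
    }
}
```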
+kernel void kernel_cpy_f32_q4_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+
+ device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK4_0].d = d;
+
+ for (int j = 0; j < QK4_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+ dst_data[i00/QK4_0].qs[j] = xi0;
+ dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q4_1(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+
+ device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < QK4_1; j++) {
+ const float v = src[j];
+ if (min > v) min = v;
+ if (max < v) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK4_1].d = d;
+ dst_data[i00/QK4_1].m = min;
+
+ for (int j = 0; j < QK4_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+ dst_data[i00/QK4_1].qs[j] = xi0;
+ dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+ }
+ }
+}
+
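
The two 4-bit copy kernels above pack a pair of quants per byte, low nibble first. A scalar sketch of the Q4_0 case (Q4_1 differs only in storing a minimum `m` and quantizing offsets from it instead of using a signed range):

```c
#include <math.h>
#include <stdint.h>

#define QK4_0 32

typedef struct {
    float   d;            /* scale (half in the real layout) */
    uint8_t qs[QK4_0/2];  /* byte j packs elements j and j + QK4_0/2 */
} block_q4_0_ref;

static void quantize_q4_0_ref(const float *src, block_q4_0_ref *dst) {
    float amax = 0.0f, max = 0.0f;           /* signed value of largest |.| */
    for (int j = 0; j < QK4_0; ++j) {
        if (fabsf(src[j]) > amax) { amax = fabsf(src[j]); max = src[j]; }
    }
    const float d  = max / -8.0f;            /* map the extreme value to -8 */
    const float id = d ? 1.0f / d : 0.0f;
    dst->d = d;
    for (int j = 0; j < QK4_0/2; ++j) {
        const uint8_t x0 = (uint8_t) fminf(15.0f, src[j]           * id + 8.5f);
        const uint8_t x1 = (uint8_t) fminf(15.0f, src[j + QK4_0/2] * id + 8.5f);
        dst->qs[j] = x0 | (x1 << 4);         /* low nibble first */
    }
}
```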
+kernel void kernel_concat(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ if (i02 < ne02) {
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+ src0_ptr += ntg.x*nb00;
+ } else {
+ ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+ src1_ptr += ntg.x*nb10;
+ }
+ dst_ptr += ntg.x*nb0;
+ }
+}
+
+//============================================ k-quants ======================================================
+
+#ifndef QK_K
+#define QK_K 256
+#else
+static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
+#endif
+
+#if QK_K == 256
+#define K_SCALE_SIZE 12
+#else
+#define K_SCALE_SIZE 4
+#endif
+
+typedef struct {
+ uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+ uint8_t qs[QK_K/4]; // quants
+ half d; // super-block scale for quantized scales
+ half dmin; // super-block scale for quantized mins
+} block_q2_K;
+// 84 bytes / block
+
+typedef struct {
+ uint8_t hmask[QK_K/8]; // quants - high bit
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
+#if QK_K == 64
+ uint8_t scales[2];
+#else
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+ half d; // super-block scale
+} block_q3_K;
+
+#if QK_K == 64
+typedef struct {
+ half d[2]; // super-block scales/mins
+ uint8_t scales[2];
+ uint8_t qs[QK_K/2]; // 4-bit quants
+} block_q4_K;
+#else
+typedef struct {
+ half d; // super-block scale for quantized scales
+ half dmin; // super-block scale for quantized mins
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
#endif
@@ -1608,32 +2213,39 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
//====================================== dot products =========================
-kernel void kernel_mul_mv_q2_K_f32(
+void kernel_mul_mv_q2_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
- const int r2 = tgpig.z;
+ const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
- const uint offset0 = r2/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
@@ -1642,11 +2254,11 @@ kernel void kernel_mul_mv_q2_K_f32(
#if QK_K == 256
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
- const int im = it/4; // 0 or 1
+ const int iq = it/4; // 0 or 1
const int ir = it%4; // 0...3
const int is = (8*ir)/16;// 0 or 1
- device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
for (int ib = ix; ib < nb; ib += 4) {
@@ -1658,8 +2270,8 @@ kernel void kernel_mul_mv_q2_K_f32(
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
}
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
device const half * dh = &x[ib].d;
for (int row = 0; row < N_DST; row++) {
@@ -1746,13 +2358,13 @@ kernel void kernel_mul_mv_q2_K_f32(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
}
}
}
-#if QK_K == 256
-kernel void kernel_mul_mv_q3_K_f32(
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@@ -1761,23 +2373,50 @@ kernel void kernel_mul_mv_q3_K_f32(
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- const int64_t r2 = tgpig.z;
+ const int64_t im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
- const uint offset0 = r2/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[32];
@@ -1899,40 +2538,47 @@ kernel void kernel_mul_mv_q3_K_f32(
}
if (tiisg == 0) {
for (int row = 0; row < 2; ++row) {
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
}
}
}
#else
-kernel void kernel_mul_mv_q3_K_f32(
+void kernel_mul_mv_q3_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- const int64_t r2 = tgpig.z;
+ const int64_t im = tgpig.z;
const int row = 2 * r0 + sgitg;
- const uint offset0 = r2/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
const int ix = tiisg/4;
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
- const int im = il/8; // 0, 0, 1, 1
+ const int iq = il/8; // 0, 0, 1, 1
const int in = il%8; // 0, 4, 0, 4
float2 sum = {0.f, 0.f};
@@ -1952,7 +2598,7 @@ kernel void kernel_mul_mv_q3_K_f32(
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
for (int l = 0; l < 4; l += 2) {
- const uint16_t hm = h[l/2] >> im;
+ const uint16_t hm = h[l/2] >> iq;
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -1968,28 +2614,50 @@ kernel void kernel_mul_mv_q3_K_f32(
const float tot = simd_sum(sumf);
if (tiisg == 0) {
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
}
}
#endif
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
#if QK_K == 256
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01 [[buffer(4)]],
- constant int64_t & ne02 [[buffer(5)]],
- constant int64_t & ne10 [[buffer(9)]],
- constant int64_t & ne12 [[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & gqa [[buffer(17)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
@@ -1997,26 +2665,32 @@ kernel void kernel_mul_mv_q4_K_f32(
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
- const int im = it/4; // 0 or 1
+ const int iq = it/4; // 0 or 1
const int ir = it%4; // 0...3
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
- const int r2 = tgpig.z;
+ const int im = tgpig.z;
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int first_row = r0 * N_DST;
const int ib_row = first_row * nb;
- const uint offset0 = r2/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
float yl[16];
float yh[16];
float sumf[N_DST]={0.f}, all_sum;
const int step = sizeof(block_q4_K) * nb / 2;
- device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
uint16_t sc16[4];
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -2031,8 +2705,8 @@ kernel void kernel_mul_mv_q4_K_f32(
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
}
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
device const half * dh = &x[ib].d;
for (int row = 0; row < N_DST; row++) {
@@ -2076,23 +2750,24 @@ kernel void kernel_mul_mv_q4_K_f32(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
}
}
}
#else
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2103,12 +2778,18 @@ kernel void kernel_mul_mv_q4_K_f32(
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
- const int r2 = tgpig.z;
+ const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
- const uint offset0 = r2/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
float yl[8];
float yh[8];
float sumf[N_DST]={0.f}, all_sum;
@@ -2164,13 +2845,14 @@ kernel void kernel_mul_mv_q4_K_f32(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
+ dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
}
}
}
#endif
-kernel void kernel_mul_mv_q5_K_f32(
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@@ -2179,23 +2861,49 @@ kernel void kernel_mul_mv_q5_K_f32(
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- const int r2 = tgpig.z;
+ const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
- const uint offset0 = r2/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float sumf[2]={0.f};
@@ -2211,15 +2919,15 @@ kernel void kernel_mul_mv_q5_K_f32(
const int tid = tiisg/4;
const int ix = tiisg%4;
- const int im = tid/4;
+ const int iq = tid/4;
const int ir = tid%4;
const int n = 8;
const int l0 = n*ir;
- const int q_offset = 32*im + l0;
- const int y_offset = 64*im + l0;
+ const int q_offset = 32*iq + l0;
+ const int y_offset = 64*iq + l0;
- const uint8_t hm1 = 1u << (2*im);
+ const uint8_t hm1 = 1u << (2*iq);
const uint8_t hm2 = hm1 << 1;
const uint8_t hm3 = hm1 << 4;
const uint8_t hm4 = hm2 << 4;
@@ -2234,7 +2942,7 @@ kernel void kernel_mul_mv_q5_K_f32(
device const uint8_t * q1 = x[i].qs + q_offset;
device const uint8_t * qh = x[i].qh + l0;
device const half * dh = &x[i].d;
- device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
device const float * y2 = y1 + 128;
float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -2290,7 +2998,7 @@ kernel void kernel_mul_mv_q5_K_f32(
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
const int ix = tiisg%8;
- const int im = il/8; // 0, 0, 1, 1
+ const int iq = il/8; // 0, 0, 1, 1
const int in = il%8; // 0, 4, 0, 4
device const float * y = yy + ix*QK_K + il;
@@ -2315,7 +3023,7 @@ kernel void kernel_mul_mv_q5_K_f32(
float2 acc = {0.f, 0.f};
for (int l = 0; l < 4; ++l) {
- const uint8_t hl = h[l] >> im;
+ const uint8_t hl = h[l] >> iq;
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -2337,27 +3045,48 @@ kernel void kernel_mul_mv_q5_K_f32(
for (int row = 0; row < 2; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
}
}
+}
+
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
}
-kernel void kernel_mul_mv_q6_K_f32(
+void kernel_mul_mv_q6_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
@@ -2368,12 +3097,17 @@ kernel void kernel_mul_mv_q6_K_f32(
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- const int r2 = tgpig.z;
+ const int im = tgpig.z;
const int row = 2 * r0 + sgitg;
- const uint offset0 = r2/gqa*(nb*ne0);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float sumf = 0;
@@ -2439,10 +3173,31 @@ kernel void kernel_mul_mv_q6_K_f32(
const float tot = simd_sum(sumf);
if (tiisg == 0) {
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
}
}
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
//============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
@@ -2717,22 +3472,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
device const void * src0,
- device const int * src1,
+ device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
constant uint64_t & nb1,
- uint tgpig[[threadgroup_position_in_grid]],
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
- uint tptg[[threads_per_threadgroup]]) {
- const int i = tgpig;
- const int r = ((device int32_t *) src1)[i];
+ uint3 tptg [[threads_per_threadgroup]]) {
+ //const int64_t i = tgpig;
+ //const int64_t r = ((device int32_t *) src1)[i];
+
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
float4x4 temp;
dequantize_func(
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+ }
+}
+
+kernel void kernel_get_rows_f32(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ }
+}
+
+kernel void kernel_get_rows_f16(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
}
}
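
The get-rows rework switches from a flat 1-D dispatch to a 2-D grid, apparently so the row-index tensor can itself be batched. A hedged C sketch of the new lookup:

```c
#include <stdint.h>

/* One gathered element: (i10, i11) addresses the row-index tensor src1;
 * the same i11 selects the src0 channel, so each batch of indices gathers
 * from its own matrix. Strides nb* are in bytes, as in ggml. */
static float get_rows_f32_ref(const char *src0, const char *src1,
                              int64_t i10, int64_t i11, int64_t ind,
                              uint64_t nb01, uint64_t nb02,
                              uint64_t nb10, uint64_t nb11) {
    const int32_t r   = *(const int32_t *)(src1 + i11*nb11 + i10*nb10);
    const int64_t i02 = i11;
    return ((const float *)(src0 + r*nb01 + i02*nb02))[ind];
}
```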
@@ -2749,24 +3572,25 @@ kernel void kernel_get_rows(
// each block_q contains 16*nl weights
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm(device const uchar * src0,
- device const uchar * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant int64_t & nb01,
- constant int64_t & nb02,
- constant int64_t & ne12,
- constant int64_t & nb10,
- constant int64_t & nb11,
- constant int64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & gqa,
- threadgroup uchar * shared_memory [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+void kernel_mul_mm_impl(device const uchar * src0,
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2792,7 +3616,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
short il = (tiitg % THREAD_PER_ROW);
- uint offset0 = im/gqa*nb02;
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -2876,17 +3703,137 @@ kernel void kernel_mul_mm(device const uchar * src0,
}
}
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm(device const uchar * src0,
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+ src0,
+ src1,
+ dst,
+ ne00,
+ ne02,
+ nb01,
+ nb02,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ shared_memory,
+ tgpig,
+ tiitg,
+ sgitg);
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm_id(
+ device const uchar * ids,
+ device const uchar * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const uchar * src00,
+ device const uchar * src01,
+ device const uchar * src02,
+ device const uchar * src03,
+ device const uchar * src04,
+ device const uchar * src05,
+ device const uchar * src06,
+ device const uchar * src07,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+ src0[id],
+ src1 + bid*nb11,
+ (device float *) (dst + bid*nb1),
+ ne00,
+ ne02,
+ nb01,
+ nb02,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ shared_memory,
+ tgpig,
+ tiitg,
+ sgitg);
+}
+
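
The `_id` wrappers implement the indirection needed for mixture-of-experts routing: the host binds up to eight candidate weight tensors, and a small index tensor decides per batch entry which one feeds the ordinary mat-mul. A C sketch of the selection step (illustration only; names follow the kernel's parameters):

```c
#include <stdint.h>

static const void * select_expert(const uint8_t *ids,  /* expert-id rows     */
                                  int64_t nbi1,        /* row stride (bytes) */
                                  int64_t bid,         /* batch entry        */
                                  int     idx,         /* column within row  */
                                  const void *src0[8]) {
    const int32_t id = ((const int32_t *)(ids + bid * nbi1))[idx];
    return src0[id];  /* the impl kernel then runs unchanged on this tensor */
}
```

Inside the kernel, `tgpig.z` is split the same way: the quotient over `ne12*ne13` is the batch entry `bid`, and the remainder is passed through to the shared `_impl` body.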
#if QK_K == 256
#define QK_NL 16
#else
#define QK_NL 4
#endif
-typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
- constant uint64_t &, constant uint64_t &, uint, uint, uint);
+//
+// get rows
+//
-template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows;
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows;
+typedef void (get_rows_t)(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3, uint, uint3);
+
+//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows;
+//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows;
@@ -2898,6 +3845,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows;
+//
+// matrix-matrix multiplication
+//
+
typedef void (mat_mm_t)(
device const uchar * src0,
device const uchar * src1,
@@ -2912,8 +3863,10 @@ typedef void (mat_mm_t)(
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & gqa,
- threadgroup uchar *, uint3, uint, uint);
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup uchar *,
+ uint3, uint, uint);
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm;
@@ -2927,3 +3880,823 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+
+//
+// indirect matrix-matrix multiplication
+//
+
+typedef void (mat_mm_id_t)(
+ device const uchar * ids,
+ device const uchar * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const uchar * src00,
+ device const uchar * src01,
+ device const uchar * src02,
+ device const uchar * src03,
+ device const uchar * src04,
+ device const uchar * src05,
+ device const uchar * src06,
+ device const uchar * src07,
+ threadgroup uchar *,
+ uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+
+//
+// matrix-vector multiplication
+//
+
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_f32_f32_impl(
+ src0[id],
+ src1 + bid*nb11,
+ (device float *) (dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
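+
+// Worked example of the index arithmetic above, with hypothetical sizes:
+// for ne12 = 2, ne13 = 2 (so ne12*ne13 = 4) and an incoming tgpig.z of 9,
+// bid = 9/4 = 2 selects the third row of `ids`, and tgpig.z becomes 9%4 = 1,
+// the batch coordinate within the matrix chosen by that row's idx-th entry.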
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_f16_f32_impl(
+ src0[id],
+ src1 + bid*nb11,
+ (device float *) (dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
+kernel void kernel_mul_mv_id_q8_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q8_0_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
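+
+// Note that the quantized *_impl calls (this one and those below) take a shorter
+// argument list than the f32/f16 variants above: the nb* byte strides are
+// dropped, presumably because the quantized paths assume a contiguous block
+// layout and derive offsets from the block size instead.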
+
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
+kernel void kernel_mul_mv_id_q4_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
+kernel void kernel_mul_mv_id_q4_1_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
+kernel void kernel_mul_mv_id_q5_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
+kernel void kernel_mul_mv_id_q5_1_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q2_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q3_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q4_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q5_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q6_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
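+
+// Host-side sketch (an assumption, for orientation only): ggml-metal.m is
+// expected to dispatch these pipelines for GGML_OP_MUL_MAT_ID nodes, binding
+// the expert weight tensors passed to ggml_mul_mat_id() to the src00..src07
+// slots and the routing tensor to `ids`.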
diff --git a/LLama/runtimes/deps/osx-arm64/libllama.dylib b/LLama/runtimes/deps/osx-arm64/libllama.dylib
index df57f7dfa..712a0be49 100644
Binary files a/LLama/runtimes/deps/osx-arm64/libllama.dylib and b/LLama/runtimes/deps/osx-arm64/libllama.dylib differ
diff --git a/LLama/runtimes/deps/osx-x64/libllama.dylib b/LLama/runtimes/deps/osx-x64/libllama.dylib
index ee6f29b47..c976111ab 100644
Binary files a/LLama/runtimes/deps/osx-x64/libllama.dylib and b/LLama/runtimes/deps/osx-x64/libllama.dylib differ