Skip to content

Commit 4bc90f4

Browse files
authored
Merge pull request #1248 from krisbiradar/add-support-for-gemma-3n
Add support for Gemma 3n
2 parents 8afd3eb + 55a7aeb commit 4bc90f4

18 files changed

+176
-32
lines changed

LLama.Unittest/LLamaContextTests.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@ public LLamaContextTests()
1313
{
1414
var @params = new ModelParams(Constants.GenerativeModelPath2)
1515
{
16-
ContextSize = 128,
16+
ContextSize = 512,
1717
BatchSize = 8,
1818
UBatchSize = 8,
19-
SeqMax = 1,
2019
VocabOnly = false,
2120
GpuLayerCount = Constants.CIGpuLayerCount,
2221
};
@@ -33,7 +32,7 @@ public void Dispose()
3332
[Fact]
3433
public void CheckProperties()
3534
{
36-
Assert.Equal(128u, _context.ContextSize);
35+
Assert.Equal(_context.NativeHandle.MaxSeq * 256, _context.ContextSize);
3736
Assert.Equal(960, _context.EmbeddingSize);
3837
Assert.Equal(49152, _context.Vocab.Count);
3938
}

LLama.Unittest/LLamaContextWithCustomLoggerTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public LLamaContextWithCustomLoggerTests()
3030
{
3131
var @params = new ModelParams(Constants.GenerativeModelPath2)
3232
{
33-
ContextSize = 128,
33+
ContextSize = 512,
3434
GpuLayerCount = Constants.CIGpuLayerCount,
3535
};
3636

@@ -55,7 +55,7 @@ public void Dispose()
5555
[Fact]
5656
public void CheckProperties()
5757
{
58-
Assert.Equal(128u, _context.ContextSize);
58+
Assert.Equal(_context.NativeHandle.MaxSeq * 256, _context.ContextSize);
5959
Assert.Equal(960, _context.EmbeddingSize);
6060
Assert.Equal(49152, _context.Vocab.Count);
6161
}

LLama.Unittest/LLamaRerankerTests.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
2020
ContextSize = 0,
2121
PoolingType = LLamaPoolingType.Rank,
2222
GpuLayerCount = Constants.CIGpuLayerCount,
23-
2423
};
2524
using var weights = LLamaWeights.LoadFromFile(@params);
2625
_reranker = new LLamaReranker(weights, @params);

LLama.Unittest/SamplingTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public void BatchedSampling()
104104
}
105105
}
106106

107-
// Add " repeat" and test whether next tokens will be "this phrase forever.".
107+
// Add " repeat" and test whether next tokens will be "this phrase forever."
108108
for (int i = 0; i < 4; i++)
109109
{
110110
for (int b = 0; b < batch_count; b++)

LLama.Web/Common/ModelOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public class ModelOptions
102102
public bool NoKqvOffload { get; set; }
103103

104104
/// <inheritdoc />
105-
public bool FlashAttention { get; set; }
105+
public bool? FlashAttention { get; set; }
106106

107107
/// <inheritdoc />
108108
public Encoding Encoding { get; set; } = Encoding.UTF8;

LLama/Abstractions/IContextParams.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ public interface IContextParams
106106
/// <summary>
107107
/// Whether to use flash attention
108108
/// </summary>
109-
bool FlashAttention { get; }
110-
109+
bool? FlashAttention { get; }
110+
111111
/// <summary>
112112
/// defragment the KV cache if holes/size &gt; defrag_threshold, Set to &lt;= 0 to disable (default)
113113
/// </summary>

LLama/Common/ModelParams.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using LLama.Abstractions;
23
using System.Text;
34
using System.Text.Json.Serialization;
@@ -95,12 +96,12 @@ public record ModelParams
9596

9697
/// <inheritdoc />
9798
public bool NoKqvOffload { get; set; }
98-
99+
99100
/// <inheritdoc />
100-
101-
public bool FlashAttention { get; set; }
101+
public bool? FlashAttention { get; set; }
102102

103103
/// <inheritdoc />
104+
[Obsolete]
104105
public float? DefragThreshold { get; set; }
105106

106107
/// <inheritdoc />

LLama/Extensions/IContextParamsExtensions.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
3737
result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
3838
result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
3939
result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.Unspecified;
40-
40+
4141
result.defrag_threshold = @params.DefragThreshold ?? -1;
4242

4343
result.cb_eval = IntPtr.Zero;
@@ -49,9 +49,16 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
4949
result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
5050
result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
5151
result.offload_kqv = !@params.NoKqvOffload;
52-
result.flash_attention = @params.FlashAttention;
5352
result.llama_pooling_type = @params.PoolingType;
5453
result.attention_type = @params.AttentionType;
54+
result.llama_flash_attn_type = @params.FlashAttention switch
55+
{
56+
true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
57+
false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
58+
null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
59+
};
60+
result.kv_unified = true;
61+
result.n_seq_max = (uint)Math.Min(Math.Max(10,result.n_ctx/8),256);
5562

5663
result.n_threads = Threads(@params.Threads);
5764
result.n_threads_batch = Threads(@params.BatchThreads);

LLama/LLamaSharp.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
</ItemGroup>
5858

5959
<PropertyGroup>
60-
<BinaryReleaseId>11dd5a44eb180e</BinaryReleaseId>
60+
<BinaryReleaseId>86587da</BinaryReleaseId>
6161
</PropertyGroup>
6262

6363
<PropertyGroup>

LLama/Native/LLamaContextParams.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ public struct LLamaContextParams
6464
/// Attention type to use for embeddings
6565
/// </summary>
6666
public LLamaAttentionType attention_type;
67+
68+
/// <summary>
69+
/// when to enable Flash Attention
70+
/// </summary>
71+
public LLamaFlashAttentionType llama_flash_attn_type;
6772

6873
/// <summary>
6974
/// RoPE base frequency, 0 = from model

0 commit comments

Comments (0)