Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LLama/Abstractions/IModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ public sealed record MetadataOverride
/// <summary>
/// Get the key being overridden by this override
/// </summary>
public string Key { get; init; }
public string Key { get; }

internal LLamaModelKvOverrideType Type { get; }

Expand Down
9 changes: 2 additions & 7 deletions LLama/Common/ChatHistory.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Collections.Generic;
using System.IO;
using System.Text.Json;
using System.Text.Json.Serialization;

Expand Down Expand Up @@ -37,6 +36,7 @@ public enum AuthorRole
/// </summary>
public class ChatHistory
{
private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };

/// <summary>
/// Chat message representation
Expand Down Expand Up @@ -96,12 +96,7 @@ public void AddMessage(AuthorRole authorRole, string content)
/// <returns></returns>
public string ToJson()
{
return JsonSerializer.Serialize(
this,
new JsonSerializerOptions()
{
WriteIndented = true
});
return JsonSerializer.Serialize(this, _jsonOptions);
}

/// <summary>
Expand Down
1 change: 0 additions & 1 deletion LLama/Common/FixedSizeQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using LLama.Extensions;

namespace LLama.Common
{
Expand Down
2 changes: 2 additions & 0 deletions LLama/Common/InferenceParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ public record InferenceParams
/// number of tokens to keep from initial prompt
/// </summary>
public int TokensKeep { get; set; } = 0;

/// <summary>
/// how many new tokens to predict (n_predict), set to -1 to infinitely generate a response
/// until it completes.
/// </summary>
public int MaxTokens { get; set; } = -1;

/// <summary>
/// logit bias for specific tokens
/// </summary>
Expand Down
1 change: 1 addition & 0 deletions LLama/Extensions/DictionaryExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public static TValue GetValueOrDefault<TKey, TValue>(this IReadOnlyDictionary<TK

internal static TValue GetValueOrDefaultImpl<TKey, TValue>(IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
{
// ReSharper disable once CanSimplifyDictionaryTryGetValueWithGetValueOrDefault (this is a shim for that method!)
return dictionary.TryGetValue(key, out var value) ? value : defaultValue;
}
}
Expand Down
8 changes: 7 additions & 1 deletion LLama/Grammars/Grammar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public sealed class Grammar
/// <summary>
/// Index of the initial rule to start from
/// </summary>
public ulong StartRuleIndex { get; set; }
public ulong StartRuleIndex { get; }

/// <summary>
/// The rules which make up this grammar
Expand Down Expand Up @@ -121,6 +121,12 @@ private void PrintRule(StringBuilder output, GrammarRule rule)
case LLamaGrammarElementType.CHAR_ALT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
break;

case LLamaGrammarElementType.END:
case LLamaGrammarElementType.ALT:
case LLamaGrammarElementType.RULE_REF:
case LLamaGrammarElementType.CHAR:
case LLamaGrammarElementType.CHAR_NOT:
default:
output.Append("] ");
break;
Expand Down
11 changes: 1 addition & 10 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public sealed class LLamaContext
/// <summary>
/// The context params set for this context
/// </summary>
public IContextParams Params { get; set; }
public IContextParams Params { get; }

/// <summary>
/// The native handle, which is used to be passed to the native APIs
Expand All @@ -56,15 +56,6 @@ public sealed class LLamaContext
/// </summary>
public Encoding Encoding { get; }

internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
{
Params = @params;

_logger = logger;
Encoding = @params.Encoding;
NativeHandle = nativeContext;
}

/// <summary>
/// Create a new LLamaContext for the given LLamaWeights
/// </summary>
Expand Down
18 changes: 8 additions & 10 deletions LLama/LLamaEmbedder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@ namespace LLama
public sealed class LLamaEmbedder
: IDisposable
{
private readonly LLamaContext _ctx;

/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;
public int EmbeddingSize => Context.EmbeddingSize;

/// <summary>
/// LLama Context
/// </summary>
public LLamaContext Context => this._ctx;
public LLamaContext Context { get; }

/// <summary>
/// Create a new embedder, using the given LLamaWeights
Expand All @@ -33,7 +31,7 @@ public sealed class LLamaEmbedder
public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
@params.EmbeddingMode = true;
_ctx = weights.CreateContext(@params, logger);
Context = weights.CreateContext(@params, logger);
}

/// <summary>
Expand Down Expand Up @@ -72,20 +70,20 @@ public float[] GetEmbeddings(string text)
/// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text, bool addBos)
{
var embed_inp_array = _ctx.Tokenize(text, addBos);
var embed_inp_array = Context.Tokenize(text, addBos);

// TODO(Rinne): deal with log of prompt

if (embed_inp_array.Length > 0)
_ctx.Eval(embed_inp_array, 0);
Context.Eval(embed_inp_array, 0);

unsafe
{
var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
return Array.Empty<float>();

return new Span<float>(embeddings, EmbeddingSize).ToArray();
return embeddings.ToArray();
}
}

Expand All @@ -94,7 +92,7 @@ public float[] GetEmbeddings(string text, bool addBos)
/// </summary>
public void Dispose()
{
_ctx.Dispose();
Context.Dispose();
}

}
Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaWeights.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public sealed class LLamaWeights
/// </summary>
public IReadOnlyDictionary<string, string> Metadata { get; set; }

internal LLamaWeights(SafeLlamaModelHandle weights)
private LLamaWeights(SafeLlamaModelHandle weights)
{
NativeHandle = weights;
Metadata = weights.ReadMetadata();
Expand Down
4 changes: 2 additions & 2 deletions LLama/Native/LLamaKvCacheView.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public struct LLamaKvCacheViewCell
/// May be negative if the cell is not populated.
/// </summary>
public LLamaPos pos;
};
}

/// <summary>
/// An updateable view of the KV cache (llama_kv_cache_view)
Expand Down Expand Up @@ -130,7 +130,7 @@ public ref LLamaKvCacheView GetView()
}
}

partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Create an empty KV cache view. (use only for debugging purposes)
Expand Down
4 changes: 2 additions & 2 deletions LLama/Native/LLamaPos.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace LLama.Native;
/// Indicates position in a sequence
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaPos
public record struct LLamaPos
{
/// <summary>
/// The raw value
Expand All @@ -17,7 +17,7 @@ public struct LLamaPos
/// Create a new LLamaPos
/// </summary>
/// <param name="value"></param>
public LLamaPos(int value)
private LLamaPos(int value)
{
Value = value;
}
Expand Down
4 changes: 2 additions & 2 deletions LLama/Native/LLamaSeqId.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace LLama.Native;
/// ID for a sequence in a batch
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaSeqId
public record struct LLamaSeqId
{
/// <summary>
/// The raw value
Expand All @@ -17,7 +17,7 @@ public struct LLamaSeqId
/// Create a new LLamaSeqId
/// </summary>
/// <param name="value"></param>
public LLamaSeqId(int value)
private LLamaSeqId(int value)
{
Value = value;
}
Expand Down
12 changes: 12 additions & 0 deletions LLama/Native/LLamaTokenType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace LLama.Native;

/// <summary>
/// The type of a single token.
/// NOTE(review): names and values presumably mirror the native <c>llama_token_type</c>
/// enum in llama.h — confirm against the native header when updating.
/// </summary>
public enum LLamaTokenType
{
    /// <summary>Token type is undefined/unset.</summary>
    LLAMA_TOKEN_TYPE_UNDEFINED = 0,

    /// <summary>A normal vocabulary token.</summary>
    LLAMA_TOKEN_TYPE_NORMAL = 1,

    /// <summary>An unknown token.</summary>
    LLAMA_TOKEN_TYPE_UNKNOWN = 2,

    /// <summary>A control token (e.g. begin/end of sequence).</summary>
    LLAMA_TOKEN_TYPE_CONTROL = 3,

    /// <summary>A user defined token.</summary>
    LLAMA_TOKEN_TYPE_USER_DEFINED = 4,

    /// <summary>An unused token slot.</summary>
    LLAMA_TOKEN_TYPE_UNUSED = 5,

    /// <summary>A raw byte token.</summary>
    LLAMA_TOKEN_TYPE_BYTE = 6,
}
2 changes: 1 addition & 1 deletion LLama/Native/NativeApi.BeamSearch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace LLama.Native;

public partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Type of pointer to the beam_search_callback function.
Expand Down
4 changes: 2 additions & 2 deletions LLama/Native/NativeApi.Grammar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace LLama.Native
{
using llama_token = Int32;

public unsafe partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Create a new grammar from the given set of grammar rules
Expand All @@ -15,7 +15,7 @@ public unsafe partial class NativeApi
/// <param name="start_rule_index"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);
public static extern unsafe IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);

/// <summary>
/// Free all memory from the given SafeLLamaGrammarHandle
Expand Down
Loading