Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LLama/Abstractions/IModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ internal void WriteValue(ref LLamaModelMetadataOverride dest)
dest.FloatValue = _valueFloat;
break;
case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
dest.BoolValue = _valueBool ? -1 : 0;
dest.BoolValue = _valueBool ? -1L : 0;
break;
default:
throw new ArgumentOutOfRangeException();
Expand Down
20 changes: 20 additions & 0 deletions LLama/Extensions/EncodingExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ namespace LLama.Extensions;
internal static class EncodingExtensions
{
#if NETSTANDARD2_0
/// <summary>
/// Span-based overload of <c>GetBytes</c>, compiled only when targeting netstandard2.0.
/// Forwards directly to <see cref="GetBytesImpl"/>.
/// </summary>
/// <param name="encoding">Encoding used to convert the characters</param>
/// <param name="chars">Characters to encode</param>
/// <param name="output">Destination buffer which receives the encoded bytes</param>
/// <returns>The value returned by <see cref="GetBytesImpl"/></returns>
public static int GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, Span<byte> output)
    => GetBytesImpl(encoding, chars, output);

public static int GetChars(this Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> output)
{
return GetCharsImpl(encoding, bytes, output);
Expand All @@ -19,6 +24,21 @@ public static int GetCharCount(this Encoding encoding, ReadOnlySpan<byte> bytes)
#error Target framework not supported!
#endif

/// <summary>
/// Encode a span of characters into a span of bytes by pinning both spans and calling the
/// pointer-based <c>Encoding.GetBytes(char*, int, byte*, int)</c> overload.
/// </summary>
/// <param name="encoding">Encoding used to convert the characters</param>
/// <param name="chars">Characters to encode</param>
/// <param name="output">Destination buffer which receives the encoded bytes</param>
/// <returns>0 if <paramref name="chars"/> is empty, otherwise the value returned by the encoding</returns>
internal static int GetBytesImpl(Encoding encoding, ReadOnlySpan<char> chars, Span<byte> output)
{
    // Short-circuit on empty input without touching the encoding at all.
    // NOTE(review): presumably this also avoids handing a null pointer (from pinning an empty span) to the encoder - confirm.
    if (chars.IsEmpty)
    {
        return 0;
    }

    unsafe
    {
        // Both spans must stay pinned for the duration of the pointer-based call
        fixed (char* source = chars)
        fixed (byte* destination = output)
        {
            var written = encoding.GetBytes(source, chars.Length, destination, output.Length);
            return written;
        }
    }
}

internal static int GetCharsImpl(Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> output)
{
if (bytes.Length == 0)
Expand Down
38 changes: 21 additions & 17 deletions LLama/Extensions/IModelParamsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,35 +45,39 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
}
else
{
// Allocate enough space for all the override items
// Allocate enough space for all the override items. Pin it in place so we can safely pass it to llama.cpp
// This is one larger than necessary. The last item indicates the end of the overrides.
var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1];
var overridesPin = overrides.AsMemory().Pin();
unsafe
{
result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer;
result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overrides.AsMemory().Pin()).Pointer;
}

// Convert each item
for (var i = 0; i < @params.MetadataOverrides.Count; i++)
{
var item = @params.MetadataOverrides[i];
var native = new LLamaModelMetadataOverride
{
Tag = item.Type
};

item.WriteValue(ref native);

// Convert key to bytes
unsafe
{
fixed (char* srcKey = item.Key)
// Get the item to convert
var item = @params.MetadataOverrides[i];

// Create the "native" representation to fill in
var native = new LLamaModelMetadataOverride
{
Encoding.UTF8.GetBytes(srcKey, 0, native.key, 128);
}
}
Tag = item.Type
};

// Write the value into the native struct
item.WriteValue(ref native);

overrides[i] = native;
// Convert key chars to bytes
var srcSpan = item.Key.AsSpan();
var dstSpan = new Span<byte>(native.key, 128);
Encoding.UTF8.GetBytes(srcSpan, dstSpan);

// Store it in the array
overrides[i] = native;
}
}
}

Expand Down
14 changes: 10 additions & 4 deletions LLama/Native/LLamaModelMetadataOverride.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,28 @@ public unsafe struct LLamaModelMetadataOverride
public LLamaModelKvOverrideType Tag;

/// <summary>
/// Value, **must** only be used if Tag == LLAMA_KV_OVERRIDE_INT
/// Add 4 bytes of padding, to align the next fields to 8 bytes
/// </summary>
[FieldOffset(132)]
private readonly int PADDING;

/// <summary>
/// Value, **must** only be used if Tag == LLAMA_KV_OVERRIDE_INT
/// </summary>
[FieldOffset(136)]
public long IntValue;

/// <summary>
/// Value, **must** only be used if Tag == LLAMA_KV_OVERRIDE_FLOAT
/// </summary>
[FieldOffset(132)]
[FieldOffset(136)]
public double FloatValue;

/// <summary>
/// Value, **must** only be used if Tag == LLAMA_KV_OVERRIDE_BOOL
/// </summary>
[FieldOffset(132)]
public int BoolValue;
[FieldOffset(136)]
public long BoolValue;
}

/// <summary>
Expand Down