diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index dcec913c5..fac1a8721 100644 --- a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -88,5 +88,35 @@ public void TokenizeEmpty() Assert.Equal(Array.Empty(), tokens); } + + [Fact] + public void SaveLoadState() + { + using var state1 = _context.GetState(); + + var stream = new MemoryStream(); + state1.Save(stream); + + stream.Position = 0; + + using var state2 = LLamaContext.State.Load(stream); + + Assert.Equal(state1.Size, state2.Size); + } + + [Fact] + public async Task SaveLoadStateAsync() + { + using var state1 = _context.GetState(); + + var stream = new MemoryStream(); + await state1.SaveAsync(stream); + + stream.Position = 0; + + using var state2 = await LLamaContext.State.LoadAsync(stream); + + Assert.Equal(state1.Size, state2.Size); + } } } diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 18b22d82e..ca3ef55c0 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -466,16 +466,15 @@ public void Dispose() public class State : SafeLLamaHandleBase { - private readonly nuint _size; /// /// Get the size in bytes of this state object /// - public nuint Size => _size; + public nuint Size { get; } internal State(IntPtr memory, nuint size) : base(memory, true) { - _size = size; + Size = size; } /// @@ -494,7 +493,8 @@ public async Task SaveAsync(Stream stream) UnmanagedMemoryStream from; unsafe { - from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size)); + var length = (long)Size; + from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), length, length, FileAccess.Read); } await from.CopyToAsync(stream); } @@ -508,7 +508,8 @@ public void Save(Stream stream) UnmanagedMemoryStream from; unsafe { - from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size)); + var length = (long)Size; + from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), length, length, FileAccess.Read); } from.CopyTo(stream); } @@ -526,7 +527,8 @@ public static async Task LoadAsync(Stream stream) UnmanagedMemoryStream dest; unsafe { - dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length); + var length = stream.Length; + dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), length, length, FileAccess.Write); } await stream.CopyToAsync(dest); @@ -543,11 +545,13 @@ public static State Load(Stream stream) var memory = Marshal.AllocHGlobal((nint)stream.Length); var state = new State(memory, (nuint)stream.Length); + UnmanagedMemoryStream dest; unsafe { - var dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length); - stream.CopyTo(dest); + var length = stream.Length; + dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), length, length, FileAccess.Write); } + stream.CopyTo(dest); return state; }