Skip to content

Commit ce3913d

Browse files
authored
Add BinaryEmbedding (#6398)
* Add BinaryEmbedding Also: - Renames the polymorphic discriminators to conform with typical lingo for these types. - Adds an Embedding.Dimensions virtual property.
1 parent c49594b commit ce3913d

File tree

8 files changed

+270
-28
lines changed

8 files changed

+270
-28
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Buffers;
6+
using System.Collections;
7+
using System.ComponentModel;
8+
using System.Text.Json;
9+
using System.Text.Json.Serialization;
10+
using Microsoft.Shared.Diagnostics;
11+
12+
namespace Microsoft.Extensions.AI;
13+
14+
/// <summary>Represents an embedding composed of a bit vector.</summary>
15+
public sealed class BinaryEmbedding : Embedding
16+
{
17+
/// <summary>The embedding vector this embedding represents.</summary>
18+
private BitArray _vector;
19+
20+
/// <summary>Initializes a new instance of the <see cref="BinaryEmbedding"/> class with the embedding vector.</summary>
21+
/// <param name="vector">The embedding vector this embedding represents.</param>
22+
/// <exception cref="ArgumentNullException"><paramref name="vector"/> is <see langword="null"/>.</exception>
23+
public BinaryEmbedding(BitArray vector)
24+
{
25+
_vector = Throw.IfNull(vector);
26+
}
27+
28+
/// <summary>Gets or sets the embedding vector this embedding represents.</summary>
29+
[JsonConverter(typeof(VectorConverter))]
30+
public BitArray Vector
31+
{
32+
get => _vector;
33+
set => _vector = Throw.IfNull(value);
34+
}
35+
36+
/// <inheritdoc />
37+
[JsonIgnore]
38+
public override int Dimensions => _vector.Length;
39+
40+
/// <summary>Provides a <see cref="JsonConverter{BitArray}"/> for serializing <see cref="BitArray"/> instances.</summary>
41+
[EditorBrowsable(EditorBrowsableState.Never)]
42+
public sealed class VectorConverter : JsonConverter<BitArray>
43+
{
44+
/// <inheritdoc/>
45+
public override BitArray Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
46+
{
47+
_ = Throw.IfNull(typeToConvert);
48+
_ = Throw.IfNull(options);
49+
50+
if (reader.TokenType != JsonTokenType.String)
51+
{
52+
throw new JsonException("Expected string property.");
53+
}
54+
55+
ReadOnlySpan<byte> utf8;
56+
byte[]? tmpArray = null;
57+
if (!reader.HasValueSequence && !reader.ValueIsEscaped)
58+
{
59+
utf8 = reader.ValueSpan;
60+
}
61+
else
62+
{
63+
// This path should be rare.
64+
int length = reader.HasValueSequence ? checked((int)reader.ValueSequence.Length) : reader.ValueSpan.Length;
65+
tmpArray = ArrayPool<byte>.Shared.Rent(length);
66+
utf8 = tmpArray.AsSpan(0, reader.CopyString(tmpArray));
67+
}
68+
69+
BitArray result = new(utf8.Length);
70+
71+
for (int i = 0; i < utf8.Length; i++)
72+
{
73+
result[i] = utf8[i] switch
74+
{
75+
(byte)'0' => false,
76+
(byte)'1' => true,
77+
_ => throw new JsonException("Expected binary character sequence.")
78+
};
79+
}
80+
81+
if (tmpArray is not null)
82+
{
83+
ArrayPool<byte>.Shared.Return(tmpArray);
84+
}
85+
86+
return result;
87+
}
88+
89+
/// <inheritdoc/>
90+
public override void Write(Utf8JsonWriter writer, BitArray value, JsonSerializerOptions options)
91+
{
92+
_ = Throw.IfNull(writer);
93+
_ = Throw.IfNull(value);
94+
_ = Throw.IfNull(options);
95+
96+
int length = value.Length;
97+
98+
byte[] tmpArray = ArrayPool<byte>.Shared.Rent(length);
99+
100+
Span<byte> utf8 = tmpArray.AsSpan(0, length);
101+
for (int i = 0; i < utf8.Length; i++)
102+
{
103+
utf8[i] = value[i] ? (byte)'1' : (byte)'0';
104+
}
105+
106+
writer.WriteStringValue(utf8);
107+
108+
ArrayPool<byte>.Shared.Return(tmpArray);
109+
}
110+
}
111+
}

src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,23 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Diagnostics;
56
using System.Text.Json.Serialization;
67

78
namespace Microsoft.Extensions.AI;
89

910
/// <summary>Represents an embedding generated by a <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>.</summary>
1011
/// <remarks>This base class provides metadata about the embedding. Derived types provide the concrete data contained in the embedding.</remarks>
1112
[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")]
13+
[JsonDerivedType(typeof(BinaryEmbedding), typeDiscriminator: "binary")]
14+
[JsonDerivedType(typeof(Embedding<byte>), typeDiscriminator: "uint8")]
15+
[JsonDerivedType(typeof(Embedding<sbyte>), typeDiscriminator: "int8")]
1216
#if NET
13-
[JsonDerivedType(typeof(Embedding<Half>), typeDiscriminator: "halves")]
17+
[JsonDerivedType(typeof(Embedding<Half>), typeDiscriminator: "float16")]
1418
#endif
15-
[JsonDerivedType(typeof(Embedding<float>), typeDiscriminator: "floats")]
16-
[JsonDerivedType(typeof(Embedding<double>), typeDiscriminator: "doubles")]
17-
[JsonDerivedType(typeof(Embedding<byte>), typeDiscriminator: "bytes")]
18-
[JsonDerivedType(typeof(Embedding<sbyte>), typeDiscriminator: "sbytes")]
19+
[JsonDerivedType(typeof(Embedding<float>), typeDiscriminator: "float32")]
20+
[JsonDerivedType(typeof(Embedding<double>), typeDiscriminator: "float64")]
21+
[DebuggerDisplay("Dimensions = {Dimensions}")]
1922
public class Embedding
2023
{
2124
/// <summary>Initializes a new instance of the <see cref="Embedding"/> class.</summary>
@@ -26,6 +29,13 @@ protected Embedding()
2629
/// <summary>Gets or sets a timestamp at which the embedding was created.</summary>
2730
public DateTimeOffset? CreatedAt { get; set; }
2831

32+
/// <summary>Gets the dimensionality of the embedding vector.</summary>
33+
/// <remarks>
34+
/// This value corresponds to the number of elements in the embedding vector.
35+
/// </remarks>
36+
[JsonIgnore]
37+
public virtual int Dimensions { get; }
38+
2939
/// <summary>Gets or sets the model ID using in the creation of the embedding.</summary>
3040
public string? ModelId { get; set; }
3141

src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Text.Json.Serialization;
56

67
namespace Microsoft.Extensions.AI;
78

@@ -19,4 +20,8 @@ public Embedding(ReadOnlyMemory<T> vector)
1920

2021
/// <summary>Gets or sets the embedding vector this embedding represents.</summary>
2122
public ReadOnlyMemory<T> Vector { get; set; }
23+
24+
/// <inheritdoc />
25+
[JsonIgnore]
26+
public override int Dimensions => Vector.Length;
2227
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections;
6+
using System.Linq;
7+
using System.Text.Json;
8+
using Xunit;
9+
10+
namespace Microsoft.Extensions.AI;
11+
12+
public class BinaryEmbeddingTests
13+
{
14+
[Fact]
15+
public void Ctor_Roundtrips()
16+
{
17+
BitArray vector = new BitArray(new bool[] { false, true, false, true });
18+
19+
BinaryEmbedding e = new(vector);
20+
Assert.Same(vector, e.Vector);
21+
Assert.Null(e.ModelId);
22+
Assert.Null(e.CreatedAt);
23+
Assert.Null(e.AdditionalProperties);
24+
}
25+
26+
[Fact]
27+
public void Properties_Roundtrips()
28+
{
29+
BitArray vector = new BitArray(new bool[] { false, true, false, true });
30+
31+
BinaryEmbedding e = new(vector);
32+
33+
Assert.Same(vector, e.Vector);
34+
BitArray newVector = new BitArray(new bool[] { true, false, true, false });
35+
e.Vector = newVector;
36+
Assert.Same(newVector, e.Vector);
37+
38+
Assert.Null(e.ModelId);
39+
e.ModelId = "text-embedding-3-small";
40+
Assert.Equal("text-embedding-3-small", e.ModelId);
41+
42+
Assert.Null(e.CreatedAt);
43+
DateTimeOffset createdAt = DateTimeOffset.Parse("2022-01-01T00:00:00Z");
44+
e.CreatedAt = createdAt;
45+
Assert.Equal(createdAt, e.CreatedAt);
46+
47+
Assert.Null(e.AdditionalProperties);
48+
AdditionalPropertiesDictionary props = new();
49+
e.AdditionalProperties = props;
50+
Assert.Same(props, e.AdditionalProperties);
51+
}
52+
53+
[Fact]
54+
public void Serialization_Roundtrips()
55+
{
56+
foreach (int length in Enumerable.Range(0, 64).Concat(new[] { 10_000 }))
57+
{
58+
bool[] bools = new bool[length];
59+
Random r = new(42);
60+
for (int i = 0; i < length; i++)
61+
{
62+
bools[i] = r.Next(2) != 0;
63+
}
64+
65+
BitArray vector = new BitArray(bools);
66+
BinaryEmbedding e = new(vector);
67+
68+
string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
69+
Assert.Equal($$"""{"$type":"binary","vector":"{{string.Concat(vector.Cast<bool>().Select(b => b ? '1' : '0'))}}"}""", json);
70+
71+
BinaryEmbedding result = Assert.IsType<BinaryEmbedding>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
72+
Assert.Equal(e.Vector, result.Vector);
73+
}
74+
}
75+
76+
[Fact]
77+
public void Derialization_SupportsEncodedBits()
78+
{
79+
BinaryEmbedding result = Assert.IsType<BinaryEmbedding>(JsonSerializer.Deserialize(
80+
"""{"$type":"binary","vector":"\u0030\u0031\u0030\u0031\u0030\u0031"}""",
81+
TestJsonSerializerContext.Default.Embedding));
82+
83+
Assert.Equal(new BitArray(new[] { false, true, false, true, false, true }), result.Vector);
84+
}
85+
86+
[Theory]
87+
[InlineData("""{"$type":"binary","vector":"\u0030\u0032"}""")]
88+
[InlineData("""{"$type":"binary","vector":"02"}""")]
89+
[InlineData("""{"$type":"binary","vector":" "}""")]
90+
[InlineData("""{"$type":"binary","vector":10101}""")]
91+
public void Derialization_InvalidBinaryEmbedding_Throws(string json)
92+
{
93+
Assert.Throws<JsonException>(() => JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
94+
}
95+
}

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class EmbeddingTests
1414
public void Embedding_Ctor_Roundtrips()
1515
{
1616
float[] floats = [1f, 2f, 3f];
17-
UsageDetails usage = new();
17+
1818
AdditionalPropertiesDictionary props = [];
1919
var createdAt = DateTimeOffset.Parse("2022-01-01T00:00:00Z");
2020
const string Model = "text-embedding-3-small";
@@ -35,6 +35,32 @@ public void Embedding_Ctor_Roundtrips()
3535
Assert.Same(floats, array.Array);
3636
}
3737

38+
[Fact]
39+
public void Embedding_Byte_SerializationRoundtrips()
40+
{
41+
byte[] bytes = [1, 2, 3];
42+
Embedding<byte> e = new(bytes);
43+
44+
string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
45+
Assert.Equal("""{"$type":"uint8","vector":"AQID"}""", json);
46+
47+
Embedding<byte> result = Assert.IsType<Embedding<byte>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
48+
Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray());
49+
}
50+
51+
[Fact]
52+
public void Embedding_SByte_SerializationRoundtrips()
53+
{
54+
sbyte[] bytes = [1, 2, 3];
55+
Embedding<sbyte> e = new(bytes);
56+
57+
string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
58+
Assert.Equal("""{"$type":"int8","vector":[1,2,3]}""", json);
59+
60+
Embedding<sbyte> result = Assert.IsType<Embedding<sbyte>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
61+
Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray());
62+
}
63+
3864
#if NET
3965
[Fact]
4066
public void Embedding_Half_SerializationRoundtrips()
@@ -43,7 +69,7 @@ public void Embedding_Half_SerializationRoundtrips()
4369
Embedding<Half> e = new(halfs);
4470

4571
string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
46-
Assert.Equal("""{"$type":"halves","vector":[1,2,3]}""", json);
72+
Assert.Equal("""{"$type":"float16","vector":[1,2,3]}""", json);
4773

4874
Embedding<Half> result = Assert.IsType<Embedding<Half>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
4975
Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray());
@@ -57,7 +83,7 @@ public void Embedding_Single_SerializationRoundtrips()
5783
Embedding<float> e = new(floats);
5884

5985
string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
60-
Assert.Equal("""{"$type":"floats","vector":[1,2,3]}""", json);
86+
Assert.Equal("""{"$type":"float32","vector":[1,2,3]}""", json);
6187

6288
Embedding<float> result = Assert.IsType<Embedding<float>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
6389
Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray());
@@ -70,7 +96,7 @@ public void Embedding_Double_SerializationRoundtrips()
7096
Embedding<double> e = new(floats);
7197

7298
string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
73-
Assert.Equal("""{"$type":"doubles","vector":[1,2,3]}""", json);
99+
Assert.Equal("""{"$type":"float64","vector":[1,2,3]}""", json);
74100

75101
Embedding<double> result = Assert.IsType<Embedding<double>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
76102
Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray());

test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs

Lines changed: 0 additions & 16 deletions
This file was deleted.

test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
#if NET
6+
using System.Collections;
7+
#endif
58
using System.Collections.Generic;
69
using System.Diagnostics;
710
using System.Diagnostics.CodeAnalysis;
@@ -148,7 +151,14 @@ public async Task Quantization_Binary_EmbeddingsCompareSuccessfully()
148151
{
149152
for (int j = 0; j < embeddings.Count; j++)
150153
{
151-
distances[i, j] = TensorPrimitives.HammingBitDistance(embeddings[i].Bits.Span, embeddings[j].Bits.Span);
154+
distances[i, j] = TensorPrimitives.HammingBitDistance<byte>(ToArray(embeddings[i].Vector), ToArray(embeddings[j].Vector));
155+
156+
static byte[] ToArray(BitArray array)
157+
{
158+
byte[] result = new byte[(array.Length + 7) / 8];
159+
array.CopyTo(result, 0);
160+
return result;
161+
}
152162
}
153163
}
154164

0 commit comments

Comments
 (0)