Skip to content

Commit b62c809

Browse files
committed
Fix recall when all vectors score the same
It is possible to insert non-equal vectors that score the same under the provided similarity measure. For example, the vectors (0.1, 0.1), (0.2, 0.2) and (0.3, 0.3) are all effectively the same point under the cosine metric, and any pair of them has a similarity of 1.0. This edge case caused serious problems with graph connectivity: queries returned at most 33 nodes, even for large graphs. This PR fixes it by improving the fairness of node selection when nodes score the same. This is achieved by a small modification to how node ids are encoded in the NodeQueue. Previously, when nodes scored the same, they were compared by node id, which always preferred the nodes added earlier. That created a huge bias. Reversing the bits of the node ids shuffles their order and breaks the systematic bias towards older nodes. The other part of the fix is making sure that nodes with the same score don't prevent backlinks from being formed. By placing new neighbors before existing neighbors with the same score, we give the new ones a chance to be linked to. This will drop some other neighbor from the list, but since that neighbor has been present in the graph for some time already, it is much more likely to be well connected already.
1 parent 6517870 commit b62c809

File tree

4 files changed

+58
-13
lines changed

4 files changed

+58
-13
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.Arrays;
3434

3535
import static java.lang.Math.min;
36+
import static java.lang.Math.random;
3637

3738
/**
3839
* NodeArray encodes nodeids and their scores relative to some other element
@@ -173,6 +174,7 @@ public int insertSorted(int newNode, float newScore) {
173174
if (size == nodes.length) {
174175
growArrays();
175176
}
177+
176178
int insertionPoint = descSortFindRightMostInsertionPoint(newScore);
177179
if (duplicateExistsNear(insertionPoint, newNode, newScore)) {
178180
return -1;
@@ -287,7 +289,7 @@ protected final int descSortFindRightMostInsertionPoint(float newScore) {
287289
int end = size - 1;
288290
while (start <= end) {
289291
int mid = (start + end) / 2;
290-
if (scores[mid] < newScore) end = mid - 1;
292+
if (scores[mid] <= newScore) end = mid - 1;
291293
else start = mid + 1;
292294
}
293295
return start;

jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,15 @@ public void pushAll(NodeScoreIterator nodeScoreIterator, int count) {
124124
private long encode(int node, float score) {
125125
assert node >= 0 : node;
126126
return order.apply(
127-
(((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
127+
(((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & Integer.reverse(node)));
128128
}
129129

130130
private float decodeScore(long heapValue) {
131131
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
132132
}
133133

134134
private int decodeNodeId(long heapValue) {
135-
return (int) ~(order.apply(heapValue));
135+
return Integer.reverse((int) order.apply(heapValue));
136136
}
137137

138138
/** Removes the top element and returns its node id. */

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,48 +60,50 @@ public void testScoresDescOrder() {
6060

6161
neighbors.insertSorted(4, 1f);
6262
assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors);
63-
assertNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
63+
assertNodesEqual(new int[] {4, 0, 3, 1}, neighbors);
6464

6565
neighbors.insertSorted(5, 1.1f);
6666
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors);
67-
assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
67+
assertNodesEqual(new int[] {5, 4, 0, 3, 1}, neighbors);
6868

6969
neighbors.insertSorted(6, 0.8f);
7070
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors);
71-
assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
71+
assertNodesEqual(new int[] {5, 4, 0, 3, 6, 1}, neighbors);
7272

7373
neighbors.insertSorted(7, 0.8f);
7474
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
75-
assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
75+
assertNodesEqual(new int[] {5, 4, 0, 3, 7, 6, 1}, neighbors);
7676

7777
neighbors.removeIndex(2);
7878
assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
79-
assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
79+
assertNodesEqual(new int[] {5, 4, 3, 7, 6, 1}, neighbors);
8080

8181
neighbors.removeIndex(0);
8282
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
83-
assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
83+
assertNodesEqual(new int[] {4, 3, 7, 6, 1}, neighbors);
8484

8585
neighbors.removeIndex(4);
8686
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors);
87-
assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
87+
assertNodesEqual(new int[] {4, 3, 7, 6}, neighbors);
8888

8989
neighbors.removeLast();
9090
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
91-
assertNodesEqual(new int[] {0, 3, 1}, neighbors);
91+
assertNodesEqual(new int[] {4, 3, 7}, neighbors);
9292

9393
neighbors.insertSorted(8, 0.9f);
9494
assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors);
95-
assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
95+
assertNodesEqual(new int[] {4, 8, 3, 7}, neighbors);
9696
}
9797

9898
private void assertScoresEqual(float[] scores, NodeArray neighbors) {
99+
assertEquals(scores.length, neighbors.size(), "Number of scores differs");
99100
for (int i = 0; i < scores.length; i++) {
100101
assertEquals(scores[i], neighbors.getScore(i), 0.01f);
101102
}
102103
}
103104

104105
private void assertNodesEqual(int[] nodes, NodeArray neighbors) {
106+
assertEquals(nodes.length, neighbors.size(), "Number of nodes differs");
105107
for (int i = 0; i < nodes.length; i++) {
106108
assertEquals(nodes[i], neighbors.getNode(i));
107109
}
@@ -181,7 +183,7 @@ public void testNoDuplicatesSameScores() {
181183
cna.insertSorted(3, 10.0f);
182184
cna.insertSorted(1, 10.0f); // This is a duplicate and should be ignored
183185
cna.insertSorted(3, 10.0f); // This is also a duplicate
184-
assertArrayEquals(new int[] {1, 2, 3}, cna.copyDenseNodes());
186+
assertArrayEquals(new int[] {3, 2, 1}, cna.copyDenseNodes());
185187
assertArrayEquals(new float[] {10.0f, 10.0f, 10.0f}, cna.copyDenseScores(), 0.01f);
186188
validateSortedByScore(cna);
187189
}

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,47 @@ public void testZeroCentroid(boolean addHierarchy) {
700700
}
701701
}
702702

703+
@Test
704+
public void testSameScoreWithCosineSimilarity()
705+
{
706+
testSameScoreWithCosineSimilarity(10);
707+
testSameScoreWithCosineSimilarity(20);
708+
testSameScoreWithCosineSimilarity(50);
709+
testSameScoreWithCosineSimilarity(100);
710+
testSameScoreWithCosineSimilarity(200);
711+
testSameScoreWithCosineSimilarity(500);
712+
testSameScoreWithCosineSimilarity(1000);
713+
}
714+
715+
private void testSameScoreWithCosineSimilarity(final int N) {
716+
// Create N vectors which differ in their magnitude but have the same direction, so they would
717+
// all have the exactly same cosine similarity to the query vector.
718+
Random rand = getRandom();
719+
VectorFloat<?>[] vectors = new VectorFloat<?>[N];
720+
for (int i = 0; i < N; i++) {
721+
float x = 0.01f + rand.nextFloat();
722+
vectors[i] = vectorTypeSupport.createFloatVector(new float[]{x, x});
723+
}
724+
MockVectorValues vectorValues = MockVectorValues.fromValues(vectors);
725+
726+
similarityFunction = VectorSimilarityFunction.COSINE;
727+
GraphIndexBuilder builder = new GraphIndexBuilder(vectorValues, similarityFunction, 10, 20, 1.0f, 1.0f, false);
728+
OnHeapGraphIndex graph = builder.build(vectorValues);
729+
730+
VectorFloat<?> query = vectorTypeSupport.createFloatVector(new float[]{0.5f, 0.5f});
731+
SearchResult result = GraphSearcher.search(query, N, vectorValues, similarityFunction, graph, Bits.ALL);
732+
733+
// In perfect world, we should return all N vectors, but this is hard to guarantee considering
734+
// the graph is built with a semi-randomized algorithm. And this is an edge case already, so
735+
// we don't want to make the graph building algorithm more complex or less performant in order to satisfy
736+
// this test. In a typical scenario we'll have many more vectors in the graph than the query limit,
737+
// so missing some results is fine. We'd fall back to brute force search anyway if limit
738+
// is the same order of magnitude as the graph size.
739+
int minExpected = (int) (N * 0.5);
740+
assertTrue("Should return almost all vectors, expected at least: " + minExpected + ", got: " + result.getNodes().length,
741+
result.getNodes().length >= minExpected);
742+
}
743+
703744
/**
704745
* Returns vectors evenly distributed around the upper unit semicircle.
705746
*/

0 commit comments

Comments
 (0)