Skip to content

Commit aaf4136

Browse files
committed
Fix recall when all vectors score the same
It is possible to insert non-equal vectors that receive the same score under the provided similarity measure. E.g. the vectors (0.1, 0.1), (0.2, 0.2), (0.3, 0.3) are all effectively the same point under the cosine metric, and any pair of them scores a similarity of 1.0. This edge case caused serious issues with graph connectivity: queries returned at most 33 nodes, even for large graphs. This PR fixes it by improving the fairness of node selection when nodes score the same. This is achieved by a small modification to how node ids are encoded in the NodeQueue. When nodes scored the same, they were compared by node id, which always preferred the nodes added earlier. That created a huge bias. Reversing the bits of the node ids shuffles their order and breaks the systematic bias towards older nodes. The other part of the fix is making sure that nodes with the same score don't block backlinks from being formed. By placing new neighbours before existing neighbours with the same score, we give them a chance to be linked to. This may drop some other neighbour from the list, but since that neighbour has been present in the graph for some time already, it is much more likely to be well connected already.
1 parent 6517870 commit aaf4136

File tree

4 files changed

+63
-20
lines changed

4 files changed

+63
-20
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ public int insertSorted(int newNode, float newScore) {
173173
if (size == nodes.length) {
174174
growArrays();
175175
}
176-
int insertionPoint = descSortFindRightMostInsertionPoint(newScore);
176+
177+
int insertionPoint = descSortFindLeftMostInsertionPoint(newScore);
177178
if (duplicateExistsNear(insertionPoint, newNode, newScore)) {
178179
return -1;
179180
}
@@ -282,12 +283,12 @@ public String toString() {
282283
return sb.toString();
283284
}
284285

285-
protected final int descSortFindRightMostInsertionPoint(float newScore) {
286+
protected final int descSortFindLeftMostInsertionPoint(float newScore) {
286287
int start = 0;
287288
int end = size - 1;
288289
while (start <= end) {
289290
int mid = (start + end) / 2;
290-
if (scores[mid] < newScore) end = mid - 1;
291+
if (scores[mid] <= newScore) end = mid - 1;
291292
else start = mid + 1;
292293
}
293294
return start;

jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,21 @@ public void pushAll(NodeScoreIterator nodeScoreIterator, int count) {
100100
}
101101

102102
/**
103-
* Encodes the node ID and its similarity score as long. If two scores are equals,
104-
* the smaller node ID wins.
103+
* Encodes the node ID and its similarity score as long.
105104
*
106105
* <p>The most significant 32 bits represent the float score, encoded as a sortable int.
107106
*
108107
* <p>The less significant 32 bits represent the node ID.
109108
*
110-
* <p>The bits representing the node ID are complemented to guarantee the win for the smaller node
111-
* ID.
109+
* <p>The bits representing the node ID are reversed to ensure no bias towards smaller or greater IDs
110+
* when scores are equal.
112111
*
113112
* <p>The AND with 0xFFFFFFFFL (a long with first 32 bit as 1) is necessary to obtain a long that
114113
* has
115114
*
116115
* <p>The most significant 32 bits to 0
117116
*
118-
* <p>The less significant 32 bits represent the node ID.
117+
* <p>The less significant 32 bits represent the encoded node ID.
119118
*
120119
* @param node the node ID
121120
* @param score the node score
@@ -124,15 +123,15 @@ public void pushAll(NodeScoreIterator nodeScoreIterator, int count) {
124123
private long encode(int node, float score) {
125124
assert node >= 0 : node;
126125
return order.apply(
127-
(((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
126+
(((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & Integer.reverse(node)));
128127
}
129128

130129
private float decodeScore(long heapValue) {
131130
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
132131
}
133132

134133
private int decodeNodeId(long heapValue) {
135-
return (int) ~(order.apply(heapValue));
134+
return Integer.reverse((int) order.apply(heapValue));
136135
}
137136

138137
/** Removes the top element and returns its node id. */

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,48 +60,50 @@ public void testScoresDescOrder() {
6060

6161
neighbors.insertSorted(4, 1f);
6262
assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors);
63-
assertNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
63+
assertNodesEqual(new int[] {4, 0, 3, 1}, neighbors);
6464

6565
neighbors.insertSorted(5, 1.1f);
6666
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors);
67-
assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
67+
assertNodesEqual(new int[] {5, 4, 0, 3, 1}, neighbors);
6868

6969
neighbors.insertSorted(6, 0.8f);
7070
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors);
71-
assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
71+
assertNodesEqual(new int[] {5, 4, 0, 3, 6, 1}, neighbors);
7272

7373
neighbors.insertSorted(7, 0.8f);
7474
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
75-
assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
75+
assertNodesEqual(new int[] {5, 4, 0, 3, 7, 6, 1}, neighbors);
7676

7777
neighbors.removeIndex(2);
7878
assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
79-
assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
79+
assertNodesEqual(new int[] {5, 4, 3, 7, 6, 1}, neighbors);
8080

8181
neighbors.removeIndex(0);
8282
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
83-
assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
83+
assertNodesEqual(new int[] {4, 3, 7, 6, 1}, neighbors);
8484

8585
neighbors.removeIndex(4);
8686
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors);
87-
assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
87+
assertNodesEqual(new int[] {4, 3, 7, 6}, neighbors);
8888

8989
neighbors.removeLast();
9090
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
91-
assertNodesEqual(new int[] {0, 3, 1}, neighbors);
91+
assertNodesEqual(new int[] {4, 3, 7}, neighbors);
9292

9393
neighbors.insertSorted(8, 0.9f);
9494
assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors);
95-
assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
95+
assertNodesEqual(new int[] {4, 8, 3, 7}, neighbors);
9696
}
9797

9898
private void assertScoresEqual(float[] scores, NodeArray neighbors) {
99+
assertEquals(scores.length, neighbors.size(), "Number of scores differs");
99100
for (int i = 0; i < scores.length; i++) {
100101
assertEquals(scores[i], neighbors.getScore(i), 0.01f);
101102
}
102103
}
103104

104105
private void assertNodesEqual(int[] nodes, NodeArray neighbors) {
106+
assertEquals(nodes.length, neighbors.size(), "Number of nodes differs");
105107
for (int i = 0; i < nodes.length; i++) {
106108
assertEquals(nodes[i], neighbors.getNode(i));
107109
}
@@ -181,7 +183,7 @@ public void testNoDuplicatesSameScores() {
181183
cna.insertSorted(3, 10.0f);
182184
cna.insertSorted(1, 10.0f); // This is a duplicate and should be ignored
183185
cna.insertSorted(3, 10.0f); // This is also a duplicate
184-
assertArrayEquals(new int[] {1, 2, 3}, cna.copyDenseNodes());
186+
assertArrayEquals(new int[] {3, 2, 1}, cna.copyDenseNodes());
185187
assertArrayEquals(new float[] {10.0f, 10.0f, 10.0f}, cna.copyDenseScores(), 0.01f);
186188
validateSortedByScore(cna);
187189
}

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,47 @@ public void testZeroCentroid(boolean addHierarchy) {
700700
}
701701
}
702702

703+
@Test
704+
public void testSameScoreWithCosineSimilarity()
705+
{
706+
testSameScoreWithCosineSimilarity(10);
707+
testSameScoreWithCosineSimilarity(20);
708+
testSameScoreWithCosineSimilarity(50);
709+
testSameScoreWithCosineSimilarity(100);
710+
testSameScoreWithCosineSimilarity(200);
711+
testSameScoreWithCosineSimilarity(500);
712+
testSameScoreWithCosineSimilarity(1000);
713+
}
714+
715+
private void testSameScoreWithCosineSimilarity(final int N) {
716+
// Create N vectors which differ in their magnitude but have the same direction, so they would
717+
// all have exactly the same cosine similarity to the query vector.
718+
Random rand = getRandom();
719+
VectorFloat<?>[] vectors = new VectorFloat<?>[N];
720+
for (int i = 0; i < N; i++) {
721+
float x = 0.01f + rand.nextFloat();
722+
vectors[i] = vectorTypeSupport.createFloatVector(new float[]{x, x});
723+
}
724+
MockVectorValues vectorValues = MockVectorValues.fromValues(vectors);
725+
726+
similarityFunction = VectorSimilarityFunction.COSINE;
727+
GraphIndexBuilder builder = new GraphIndexBuilder(vectorValues, similarityFunction, 10, 20, 1.0f, 1.0f, false);
728+
OnHeapGraphIndex graph = builder.build(vectorValues);
729+
730+
VectorFloat<?> query = vectorTypeSupport.createFloatVector(new float[]{0.5f, 0.5f});
731+
SearchResult result = GraphSearcher.search(query, N, vectorValues, similarityFunction, graph, Bits.ALL);
732+
733+
// In perfect world, we should return all N vectors, but this is hard to guarantee considering
734+
// the graph is built with a semi-randomized algorithm. And this is an edge case already, so
735+
// we don't want to make the graph building algorithm more complex or less performant in order to satisfy
736+
// this test. In a typical scenario we'll have many more vectors in the graph than the query limit,
737+
// so missing some results is fine. We'd fall back to brute force search anyway if limit
738+
// is the same order of magnitude as the graph size.
739+
int minExpected = (int) (N * 0.5);
740+
assertTrue("Should return almost all vectors, expected at least: " + minExpected + ", got: " + result.getNodes().length,
741+
result.getNodes().length >= minExpected);
742+
}
743+
703744
/**
704745
* Returns vectors evenly distributed around the upper unit semicircle.
705746
*/

0 commit comments

Comments
 (0)