Skip to content

Commit 26c0c8b

Browse files
committed
Fix recall when all vectors score the same
It is possible to insert non-equal vectors which score the same under the provided similarity measure. E.g. the vectors (0.1, 0.1), (0.2, 0.2), (0.3, 0.3) are all effectively the same point under the cosine metric, and any pair of them scores a similarity of 1.0. This edge case caused some serious issues with graph connectivity: queries returned at most 33 nodes, even for large graphs. This PR fixes it by improving the fairness of node selection when nodes score the same. This is achieved by a small modification to how node ids are encoded in the NodeQueue. When nodes scored the same, they were compared by node id, which always preferred the nodes added earlier. That created a huge bias. Reversing the bits of the node ids shuffles their order and breaks the systematic bias towards the older nodes. The other part of the fix is making sure that nodes with the same score don't block backlinks from being formed. By placing new neighbors before existing neighbors with the same score, we give them a chance to be linked to. This will drop some other neighbor from the list, but considering that neighbor has been present in the graph for some time already, it is much more likely to be well connected already.
1 parent ebcaf8e commit 26c0c8b

File tree

6 files changed

+63
-24
lines changed

6 files changed

+63
-24
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
3131
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
3232
import io.github.jbellis.jvector.vector.types.VectorFloat;
33-
import org.agrona.collections.IntArrayList;
3433
import org.slf4j.Logger;
3534
import org.slf4j.LoggerFactory;
3635

@@ -455,7 +454,7 @@ private void improveConnections(int node) {
455454
var ssp = scoreProvider.searchProviderFor(node);
456455
var bits = new ExcludingBits(node);
457456
try (var gs = searchers.get()) {
458-
gs.initializeInternal(ssp, graph.entry(), bits);
457+
gs.initializeInternal(ssp, graph.entry(), bits, this.rng.nextInt());
459458
var acceptedBits = Bits.intersectionOf(bits, gs.getView().liveNodes());
460459

461460
// Move downward from entry.level to 0
@@ -566,7 +565,7 @@ public long addGraphNode(int node, SearchScoreProvider searchScoreProvider) {
566565
if (entry == null) {
567566
result = new SearchResult(new NodeScore[] {}, 0, 0, 0, 0, 0);
568567
} else {
569-
gs.initializeInternal(searchScoreProvider, entry, bits);
568+
gs.initializeInternal(searchScoreProvider, entry, bits, this.rng.nextInt());
570569

571570
// Move downward from entry.level to 1
572571
for (int lvl = entry.level; lvl > 0; lvl--) {

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ protected void internalSearch(SearchScoreProvider scoreProvider,
254254
float threshold,
255255
Bits acceptOrds)
256256
{
257-
initializeInternal(scoreProvider, entry, acceptOrds);
257+
initializeInternal(scoreProvider, entry, acceptOrds, 0);
258258

259259
// Move downward from entry.level to 1
260260
for (int lvl = entry.level; lvl > 0; lvl--) {
@@ -316,7 +316,7 @@ void setEntryPointsFromPreviousLayer() {
316316
approximateResults.clear();
317317
}
318318

319-
void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bits rawAcceptOrds) {
319+
void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bits rawAcceptOrds, int tieBreakerSeed) {
320320
// save search parameters for potential later resume
321321
initializeScoreProvider(scoreProvider);
322322
this.acceptOrds = Bits.intersectionOf(rawAcceptOrds, view.liveNodes());

jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.Arrays;
3434

3535
import static java.lang.Math.min;
36+
import static java.lang.Math.random;
3637

3738
/**
3839
* NodeArray encodes nodeids and their scores relative to some other element
@@ -168,7 +169,7 @@ public void addInOrder(int newNode, float newScore) {
168169
* or -1 if the node already exists in the array (with the same score).
169170
*/
170171
int insertionPoint(int newNode, float newScore) {
171-
int insertionPoint = descSortFindRightMostInsertionPoint(newScore);
172+
int insertionPoint = descSortFindLeftMostInsertionPoint(newScore);
172173
return duplicateExistsNear(insertionPoint, newNode, newScore) ? -1 : insertionPoint;
173174
}
174175

@@ -305,12 +306,12 @@ public String toString() {
305306
return sb.toString();
306307
}
307308

308-
protected final int descSortFindRightMostInsertionPoint(float newScore) {
309+
protected final int descSortFindLeftMostInsertionPoint(float newScore) {
309310
int start = 0;
310311
int end = size - 1;
311312
while (start <= end) {
312313
int mid = (start + end) / 2;
313-
if (scores[mid] < newScore) end = mid - 1;
314+
if (scores[mid] <= newScore) end = mid - 1;
314315
else start = mid + 1;
315316
}
316317
return start;

jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,15 @@ public void pushMany(NodeScoreIterator nodeScoreIterator, int count) {
125125
private long encode(int node, float score) {
126126
assert node >= 0 : node;
127127
return order.apply(
128-
(((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
128+
(((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & Integer.reverse(node)));
129129
}
130130

131131
private float decodeScore(long heapValue) {
132132
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
133133
}
134134

135135
private int decodeNodeId(long heapValue) {
136-
return (int) ~(order.apply(heapValue));
136+
return Integer.reverse((int) order.apply(heapValue));
137137
}
138138

139139
/** Removes the top element and returns its node id. */

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,48 +60,50 @@ public void testScoresDescOrder() {
6060

6161
neighbors.insertSorted(4, 1f);
6262
assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors);
63-
assertNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
63+
assertNodesEqual(new int[] {4, 0, 3, 1}, neighbors);
6464

6565
neighbors.insertSorted(5, 1.1f);
6666
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors);
67-
assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
67+
assertNodesEqual(new int[] {5, 4, 0, 3, 1}, neighbors);
6868

6969
neighbors.insertSorted(6, 0.8f);
7070
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors);
71-
assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
71+
assertNodesEqual(new int[] {5, 4, 0, 3, 6, 1}, neighbors);
7272

7373
neighbors.insertSorted(7, 0.8f);
7474
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
75-
assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
75+
assertNodesEqual(new int[] {5, 4, 0, 3, 7, 6, 1}, neighbors);
7676

7777
neighbors.removeIndex(2);
7878
assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
79-
assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
79+
assertNodesEqual(new int[] {5, 4, 3, 7, 6, 1}, neighbors);
8080

8181
neighbors.removeIndex(0);
8282
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
83-
assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
83+
assertNodesEqual(new int[] {4, 3, 7, 6, 1}, neighbors);
8484

8585
neighbors.removeIndex(4);
8686
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors);
87-
assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
87+
assertNodesEqual(new int[] {4, 3, 7, 6}, neighbors);
8888

8989
neighbors.removeLast();
9090
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
91-
assertNodesEqual(new int[] {0, 3, 1}, neighbors);
91+
assertNodesEqual(new int[] {4, 3, 7}, neighbors);
9292

9393
neighbors.insertSorted(8, 0.9f);
9494
assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors);
95-
assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
95+
assertNodesEqual(new int[] {4, 8, 3, 7}, neighbors);
9696
}
9797

9898
private void assertScoresEqual(float[] scores, NodeArray neighbors) {
99+
assertEquals(scores.length, neighbors.size(), "Number of scores differs");
99100
for (int i = 0; i < scores.length; i++) {
100101
assertEquals(scores[i], neighbors.getScore(i), 0.01f);
101102
}
102103
}
103104

104105
private void assertNodesEqual(int[] nodes, NodeArray neighbors) {
106+
assertEquals(nodes.length, neighbors.size(), "Number of nodes differs");
105107
for (int i = 0; i < nodes.length; i++) {
106108
assertEquals(nodes[i], neighbors.getNode(i));
107109
}
@@ -181,7 +183,7 @@ public void testNoDuplicatesSameScores() {
181183
cna.insertSorted(3, 10.0f);
182184
cna.insertSorted(1, 10.0f); // This is a duplicate and should be ignored
183185
cna.insertSorted(3, 10.0f); // This is also a duplicate
184-
assertArrayEquals(new int[] {1, 2, 3}, cna.copyDenseNodes());
186+
assertArrayEquals(new int[] {3, 2, 1}, cna.copyDenseNodes());
185187
assertArrayEquals(new float[] {10.0f, 10.0f, 10.0f}, cna.copyDenseScores(), 0.01f);
186188
validateSortedByScore(cna);
187189
}

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@
2929
import io.github.jbellis.jvector.LuceneTestCase;
3030
import io.github.jbellis.jvector.TestUtil;
3131
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
32-
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
33-
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
34-
import io.github.jbellis.jvector.quantization.PQVectors;
3532
import io.github.jbellis.jvector.quantization.ProductQuantization;
3633
import io.github.jbellis.jvector.util.Bits;
3734
import io.github.jbellis.jvector.util.BoundedLongHeap;
@@ -45,7 +42,6 @@
4542

4643
import java.util.Arrays;
4744
import java.util.HashSet;
48-
import java.util.Iterator;
4945
import java.util.List;
5046
import java.util.Random;
5147
import java.util.Set;
@@ -702,6 +698,47 @@ public void testZeroCentroid(boolean addHierarchy) {
702698
}
703699
}
704700

701+
@Test
702+
public void testSameScoreWithCosineSimilarity()
703+
{
704+
testSameScoreWithCosineSimilarity(10);
705+
testSameScoreWithCosineSimilarity(20);
706+
testSameScoreWithCosineSimilarity(50);
707+
testSameScoreWithCosineSimilarity(100);
708+
testSameScoreWithCosineSimilarity(200);
709+
testSameScoreWithCosineSimilarity(500);
710+
testSameScoreWithCosineSimilarity(1000);
711+
}
712+
713+
private void testSameScoreWithCosineSimilarity(final int N) {
714+
// Create N vectors which differ in their magnitude but have the same direction, so they would
715+
// all have exactly the same cosine similarity to the query vector.
716+
Random rand = getRandom();
717+
VectorFloat<?>[] vectors = new VectorFloat<?>[N];
718+
for (int i = 0; i < N; i++) {
719+
float x = 0.01f + rand.nextFloat();
720+
vectors[i] = vectorTypeSupport.createFloatVector(new float[]{x, x});
721+
}
722+
MockVectorValues vectorValues = MockVectorValues.fromValues(vectors);
723+
724+
similarityFunction = VectorSimilarityFunction.COSINE;
725+
GraphIndexBuilder builder = new GraphIndexBuilder(vectorValues, similarityFunction, 10, 20, 1.0f, 1.0f, false);
726+
OnHeapGraphIndex graph = builder.build(vectorValues);
727+
728+
VectorFloat<?> query = vectorTypeSupport.createFloatVector(new float[]{0.5f, 0.5f});
729+
SearchResult result = GraphSearcher.search(query, N, vectorValues, similarityFunction, graph, Bits.ALL);
730+
731+
// In a perfect world, we would return all N vectors, but this is hard to guarantee considering
732+
// the graph is built with a semi-randomized algorithm. And this is an edge case already, so
733+
// we don't want to make the graph building algorithm more complex or less performant in order to satisfy
734+
// this test. In a typical scenario we'll have many more vectors in the graph than the query limit,
735+
// so missing some results is fine. We'd fall back to brute force search anyway if limit
736+
// is the same order of magnitude as the graph size.
737+
int minExpected = (int) (N * 0.5);
738+
assertTrue("Should return almost all vectors, expected at least: " + minExpected + ", got: " + result.getNodes().length,
739+
result.getNodes().length >= minExpected);
740+
}
741+
705742
/**
706743
* Returns vectors evenly distributed around the upper unit semicircle.
707744
*/

0 commit comments

Comments
 (0)