From 66cb1528cf167b6011687ab2659d41eacf3ef47a Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Wed, 17 Sep 2025 12:48:11 -0700 Subject: [PATCH 1/5] Ensures that no node duplicates exist in the adjacency list of any node. --- .../jbellis/jvector/graph/NodeArray.java | 55 +++++++------------ 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java index 9650cece..272dcaeb 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java @@ -64,9 +64,8 @@ static NodeArray merge(NodeArray a1, NodeArray a2) { NodeArray merged = new NodeArray(a1.size() + a2.size()); int i = 0, j = 0; - // since nodes are only guaranteed to be sorted by score -- ties can appear in any node order -- - // we need to remember all the nodes with the current score to avoid adding duplicates - var nodesWithLastScore = new IntHashSet(); + // To avoid duplicates, we need to remember all the nodes added so far + var mergedNodes = new IntHashSet(); float lastAddedScore = Float.NaN; // loop through both source arrays, adding the highest score element to the merged array, @@ -75,33 +74,30 @@ static NodeArray merge(NodeArray a1, NodeArray a2) { if (a1.scores[i] < a2.scores[j]) { // add from a2 if (a2.scores[j] != lastAddedScore) { - nodesWithLastScore.clear(); lastAddedScore = a2.scores[j]; } - if (nodesWithLastScore.add(a2.nodes[j])) { + if (mergedNodes.add(a2.nodes[j])) { merged.addInOrder(a2.nodes[j], a2.scores[j]); } j++; } else if (a1.scores[i] > a2.scores[j]) { // add from a1 if (a1.scores[i] != lastAddedScore) { - nodesWithLastScore.clear(); lastAddedScore = a1.scores[i]; } - if (nodesWithLastScore.add(a1.nodes[i])) { + if (mergedNodes.add(a1.nodes[i])) { merged.addInOrder(a1.nodes[i], a1.scores[i]); } i++; } else { // same score -- add both if (a1.scores[i] != lastAddedScore) { - nodesWithLastScore.clear(); lastAddedScore = a1.scores[i]; } - if (nodesWithLastScore.add(a1.nodes[i])) { + if (mergedNodes.add(a1.nodes[i])) { merged.addInOrder(a1.nodes[i], a1.scores[i]); } - if (nodesWithLastScore.add(a2.nodes[j])) { + if (mergedNodes.add(a2.nodes[j])) { merged.addInOrder(a2.nodes[j], a2.scores[j]); } i++; @@ -112,30 +108,22 @@ static NodeArray merge(NodeArray a1, NodeArray a2) { // If elements remain in a1, add them if (i < a1.size()) { // avoid duplicates while adding nodes with the same score - while (i < a1.size && a1.scores[i] == lastAddedScore) { - if (!nodesWithLastScore.contains(a1.nodes[i])) { + for (; i < a1.size; j++) { + if (mergedNodes.add(a1.nodes[i])) { merged.addInOrder(a1.nodes[i], a1.scores[i]); } - i++; } - // the remaining nodes have a different score, so we can bulk-add them - System.arraycopy(a1.nodes, i, merged.nodes, merged.size, a1.size - i); - System.arraycopy(a1.scores, i, merged.scores, merged.size, a1.size - i); merged.size += a1.size - i; } // If elements remain in a2, add them if (j < a2.size()) { // avoid duplicates while adding nodes with the same score - while (j < a2.size && a2.scores[j] == lastAddedScore) { - if (!nodesWithLastScore.contains(a2.nodes[j])) { + for (; j < a2.size; j++) { + if (mergedNodes.add(a2.nodes[j])) { merged.addInOrder(a2.nodes[j], a2.scores[j]); } - j++; } - // the remaining nodes have a different score, so we can bulk-add them - System.arraycopy(a2.nodes, j, merged.nodes, merged.size, a2.size - j); - System.arraycopy(a2.scores, j, merged.scores, merged.size, a2.size - j); merged.size += a2.size - j; } @@ -169,7 +157,7 @@ public void addInOrder(int newNode, float newScore) { */ int insertionPoint(int newNode, float newScore) { int insertionPoint = descSortFindRightMostInsertionPoint(newScore); - return duplicateExistsNear(insertionPoint, newNode, newScore) ? -1 : insertionPoint; + return duplicateExists(insertionPoint, newNode) ? -1 : insertionPoint; } /** @@ -209,21 +197,20 @@ private int insertInternal(int insertionPoint, int newNode, float newScore) { return insertionPoint; } - private boolean duplicateExistsNear(int insertionPoint, int newNode, float newScore) { - // Check to the left - for (int i = insertionPoint - 1; i >= 0 && scores[i] == newScore; i--) { - if (nodes[i] == newNode) { - return true; + private boolean duplicateExists(int insertionPoint, int newNode) { + // Checking in close to the insertion point first should be better that doing a scan from 0 to size + for (int i = 0; i < size ; i++) { + if (insertionPoint >= i && insertionPoint < size) { + if (nodes[insertionPoint - i] == newNode) { + return true; + } } - } - - // Check to the right - for (int i = insertionPoint; i < size && scores[i] == newScore; i++) { - if (nodes[i] == newNode) { + if (insertionPoint + i < size) { + if(nodes[insertionPoint + i] == newNode) { return true; } + } } - return false; } From 6c68ad798d99cafce0d9e03a709b98a04d80bebf Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Wed, 17 Sep 2025 13:15:13 -0700 Subject: [PATCH 2/5] Fix grammar in comment --- .../main/java/io/github/jbellis/jvector/graph/NodeArray.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java index 272dcaeb..f880de69 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java @@ -198,7 +198,7 @@ private int insertInternal(int insertionPoint, int newNode, float newScore) { } private boolean duplicateExists(int insertionPoint, int newNode) { - // Checking in close to the insertion point first should be better that doing a scan from 0 to size + // Checking close to the insertion point first should be better that doing a scan from 0 to size for (int i = 0; i < size ; i++) { if (insertionPoint >= i && insertionPoint < size) { if (nodes[insertionPoint - i] == newNode) { From a483317f84aa5e418cfb70197f6d601326aa9dd2 Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Wed, 17 Sep 2025 13:35:07 -0700 Subject: [PATCH 3/5] Fix minor bugs --- .../java/io/github/jbellis/jvector/graph/NodeArray.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java index f880de69..137781db 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java @@ -108,7 +108,7 @@ static NodeArray merge(NodeArray a1, NodeArray a2) { // If elements remain in a1, add them if (i < a1.size()) { // avoid duplicates while adding nodes with the same score - for (; i < a1.size; j++) { + for (; i < a1.size; i++) { if (mergedNodes.add(a1.nodes[i])) { merged.addInOrder(a1.nodes[i], a1.scores[i]); } @@ -199,8 +199,8 @@ private int insertInternal(int insertionPoint, int newNode, float newScore) { private boolean duplicateExists(int insertionPoint, int newNode) { // Checking close to the insertion point first should be better that doing a scan from 0 to size - for (int i = 0; i < size ; i++) { - if (insertionPoint >= i && insertionPoint < size) { + for (int i = 0; i < size + 1; i++) { + if (insertionPoint >= i && insertionPoint - i < size) { if (nodes[insertionPoint - i] == newNode) { return true; } From e15f15360a0634a10a26f04a112be7ac3b1a92bb Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Fri, 19 Sep 2025 08:06:03 -0700 Subject: [PATCH 4/5] Overhaul strategy to have unique edges in the graph. --- .../jvector/graph/ConcurrentNeighborMap.java | 42 +---- .../jvector/graph/GraphIndexBuilder.java | 48 ++--- .../jbellis/jvector/graph/NodeArray.java | 165 ++++++++---------- .../graph/diversity/DiversityProvider.java | 2 +- .../diversity/VamanaDiversityProvider.java | 14 +- .../jvector/graph/GraphIndexBuilderTest.java | 9 +- .../jbellis/jvector/graph/TestNeighbors.java | 6 +- .../jbellis/jvector/graph/TestNodeArray.java | 72 ++++---- 8 files changed, 153 insertions(+), 205 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java index 891fda75..94879d27 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java @@ -16,7 +16,6 @@ package io.github.jbellis.jvector.graph; -import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.graph.diversity.DiversityProvider; import io.github.jbellis.jvector.util.BitSet; import io.github.jbellis.jvector.util.Bits; @@ -24,8 +23,6 @@ import io.github.jbellis.jvector.util.FixedBitSet; import io.github.jbellis.jvector.util.IntMap; -import static java.lang.Math.min; - /** * Encapsulates operations on a graph's neighbors. */ @@ -179,9 +176,6 @@ public static class Neighbors extends NodeArray { /** the node id whose neighbors we are storing */ private final int nodeId; - /** entries in `nodes` before this index are diverse and don't need to be checked again */ - private int diverseBefore; - /** * uses the node and score references directly from `nodeArray`, without copying * `nodeArray` is assumed to have had diversity enforced already @@ -189,7 +183,6 @@ public static class Neighbors extends NodeArray { private Neighbors(int nodeId, NodeArray nodeArray) { super(nodeArray); this.nodeId = nodeId; - this.diverseBefore = size(); } public NodesIterator iterator() { @@ -217,8 +210,7 @@ private NeighborWithShortEdges enforceDegree(ConcurrentNeighborMap map) { return new NeighborWithShortEdges(this, Double.NaN); } var next = copy(); - double shortEdges = retainDiverseInternal(next, diverseBefore, map); - next.diverseBefore = next.size(); + double shortEdges = retainDiverseInternal(next, map); return new NeighborWithShortEdges(next, shortEdges); } @@ -234,7 +226,7 @@ private Neighbors replaceDeletedNeighbors(Bits deletedNodes, NodeArray candidate // merge the remaining neighbors with the candidates and keep the diverse results NodeArray merged = NodeArray.merge(liveNeighbors, candidates); - retainDiverseInternal(merged, 0, map); + retainDiverseInternal(merged, map); return new Neighbors(nodeId, merged); } @@ -254,10 +246,10 @@ private Neighbors insertDiverse(NodeArray toMerge, ConcurrentNeighborMap map) { NodeArray merged; if (size() > 0) { merged = NodeArray.merge(this, toMerge); - retainDiverseInternal(merged, 0, map); + retainDiverseInternal(merged, map); } else { merged = toMerge.copy(); // still need to copy in case we lose the race - retainDiverseInternal(merged, 0, map); + retainDiverseInternal(merged, map); } // insertDiverse usually gets called with a LOT of candidates, so trim down the resulting NodeArray var nextNodes = merged.getArrayLength() <= map.nodeArrayLength() @@ -275,7 +267,6 @@ private Neighbors insertNotDiverse(int node, float score, ConcurrentNeighborMap // node already existed in the array -- this is rare enough that we don't check up front return this; } - next.diverseBefore = min(insertedAt, diverseBefore); return next; } @@ -283,9 +274,9 @@ private Neighbors insertNotDiverse(int node, float score, ConcurrentNeighborMap * Retain the diverse neighbors, updating `neighbors` in place * @return post-diversity short edges fraction */ - private double retainDiverseInternal(NodeArray neighbors, int diverseBefore, ConcurrentNeighborMap map) { + private double retainDiverseInternal(NodeArray neighbors, ConcurrentNeighborMap map) { BitSet selected = new FixedBitSet(neighbors.size()); - double shortEdges = map.diversityProvider.retainDiverse(neighbors, map.maxDegree, diverseBefore, selected); + double shortEdges = map.diversityProvider.retainDiverse(neighbors, map.maxDegree, selected); neighbors.retain(selected); return shortEdges; } @@ -302,7 +293,7 @@ private Neighbors insert(int neighborId, float score, float overflow, Concurrent assert hardMax <= map.maxOverflowDegree : String.format("overflow %s could exceed max overflow degree %d", overflow, map.maxOverflowDegree); - var insertionPoint = insertionPoint(neighborId, score); + var insertionPoint = insertionPoint(neighborId); if (insertionPoint == -1) { // "new" node already existed return null; @@ -313,10 +304,8 @@ private Neighbors insert(int neighborId, float score, float overflow, Concurrent // batch up the enforcement of the max connection limit, since otherwise // we do a lot of duplicate work scanning nodes that we won't remove - next.diverseBefore = min(insertionPoint, diverseBefore); if (next.size() > hardMax) { - retainDiverseInternal(next, next.diverseBefore, map); - next.diverseBefore = next.size(); + retainDiverseInternal(next, map); } return next; @@ -324,20 +313,7 @@ private Neighbors insert(int neighborId, float score, float overflow, Concurrent public static long ramBytesUsed(int count) { return NodeArray.ramBytesUsed(count) // includes our object header - + Integer.BYTES // nodeId - + Integer.BYTES; // diverseBefore - } - - /** Only for testing; this is a linear search */ - @VisibleForTesting - boolean contains(int i) { - var it = this.iterator(); - while (it.hasNext()) { - if (it.nextInt() == i) { - return true; - } - } - return false; + + Integer.BYTES; // nodeId } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index fea92bd7..744b041d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -63,8 +63,7 @@ public class GraphIndexBuilder implements Closeable { private static final Logger logger = LoggerFactory.getLogger(GraphIndexBuilder.class); private final int beamWidth; - private final ExplicitThreadLocal naturalScratch; - private final ExplicitThreadLocal concurrentScratch; + private final ExplicitThreadLocal scratch; private final int dimension; private final float neighborOverflow; @@ -335,8 +334,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, }); // in scratch we store candidates in reverse order: worse candidates are first - this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1))); - this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1))); + this.scratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(2 * max(beamWidth, graph.maxDegree() + 1))); this.rng = new Random(0); } @@ -377,10 +375,10 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi var newNeighbors = new NodeArray(oldNeighborsIt.size()); while (oldNeighborsIt.hasNext()) { int neighbor = oldNeighborsIt.nextInt(); - // since we're using a different score provider, use insertSorted instead of addInOrder - newNeighbors.insertSorted(neighbor, sf.similarityTo(neighbor)); + newNeighbors.addInOrder(neighbor, sf.similarityTo(neighbor)); } newBuilder.graph.addNode(lvl, i, newNeighbors); + newBuilder.graph.markComplete(new NodeAtLevel(lvl, i)); } }); }).join(); @@ -557,8 +555,7 @@ public long addGraphNode(int node, SearchScoreProvider searchScoreProvider) { var inProgressBefore = insertionsInProgress.clone(); try (var gs = searchers.get()) { gs.setView(graph.getView()); // new snapshot - var naturalScratchPooled = naturalScratch.get(); - var concurrentScratchPooled = concurrentScratch.get(); + var scratchPooled = scratch.get(); var bits = new ExcludingBits(nodeLevel.node); var entry = graph.entry(); @@ -581,7 +578,7 @@ public long addGraphNode(int node, SearchScoreProvider searchScoreProvider) { neighbors[index.getAndIncrement()] = new NodeScore(neighbor, score); }); Arrays.sort(neighbors); - updateNeighborsOneLayer(lvl, nodeLevel.node, neighbors, naturalScratchPooled, inProgressBefore, concurrentScratchPooled, searchScoreProvider); + updateNeighborsOneLayer(lvl, nodeLevel.node, neighbors, scratchPooled, inProgressBefore, searchScoreProvider); } gs.setEntryPointsFromPreviousLayer(); } @@ -590,7 +587,7 @@ public long addGraphNode(int node, SearchScoreProvider searchScoreProvider) { result = gs.resume(beamWidth, beamWidth, 0.0f, 0.0f); } - updateNeighborsOneLayer(0, nodeLevel.node, result.getNodes(), naturalScratchPooled, inProgressBefore, concurrentScratchPooled, searchScoreProvider); + updateNeighborsOneLayer(0, nodeLevel.node, result.getNodes(), scratchPooled, inProgressBefore, searchScoreProvider); graph.markComplete(nodeLevel); } catch (Exception e) { @@ -602,16 +599,20 @@ public long addGraphNode(int node, SearchScoreProvider searchScoreProvider) { return IntStream.range(0, nodeLevel.level).mapToLong(graph::ramBytesUsedOneNode).sum(); } - private void updateNeighborsOneLayer(int layer, int node, NodeScore[] neighbors, NodeArray naturalScratchPooled, ConcurrentSkipListSet inProgressBefore, NodeArray concurrentScratchPooled, SearchScoreProvider ssp) { + private void updateNeighborsOneLayer(int layer, int node, NodeScore[] neighbors, NodeArray scratchPooled, ConcurrentSkipListSet inProgressBefore, SearchScoreProvider ssp) { // Update neighbors with these candidates. // The DiskANN paper calls for using the entire set of visited nodes along the search path as // potential candidates, but in practice we observe neighbor lists being completely filled using // just the topK results. (Since the Robust Prune algorithm prioritizes closer neighbors, // this means that considering additional nodes from the search path, that are by definition // farther away than the ones in the topK, would not change the result.) - var natural = toScratchCandidates(neighbors, naturalScratchPooled); - var concurrent = getConcurrentCandidates(layer, node, inProgressBefore, concurrentScratchPooled, ssp.scoreFunction()); - updateNeighbors(layer, node, natural, concurrent); + scratchPooled.clear(); + toScratchCandidates(neighbors, scratchPooled); + getConcurrentCandidates(layer, node, inProgressBefore, scratchPooled, ssp.scoreFunction()); + + // toMerge may be approximate-scored, but insertDiverse will compute exact scores for the diverse ones + var graphNeighbors = graph.layers.get(layer).insertDiverse(node, scratchPooled); + graph.layers.get(layer).backlink(graphNeighbors, node, neighborOverflow); } @VisibleForTesting @@ -742,25 +743,9 @@ public synchronized long removeDeletedNodes() { return memorySize; } - private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) { - // if either natural or concurrent is empty, skip the merge - NodeArray toMerge; - if (concurrent.size() == 0) { - toMerge = natural; - } else if (natural.size() == 0) { - toMerge = concurrent; - } else { - toMerge = NodeArray.merge(natural, concurrent); - } - // toMerge may be approximate-scored, but insertDiverse will compute exact scores for the diverse ones - var neighbors = graph.layers.get(layer).insertDiverse(nodeId, toMerge); - graph.layers.get(layer).backlink(neighbors, nodeId, neighborOverflow); - } - private static NodeArray toScratchCandidates(NodeScore[] candidates, NodeArray scratch) { - scratch.clear(); for (var candidate : candidates) { - scratch.addInOrder(candidate.node, candidate.score); + scratch.insertSorted(candidate.node, candidate.score); } return scratch; } @@ -771,7 +756,6 @@ private NodeArray getConcurrentCandidates(int layer, NodeArray scratch, ScoreFunction scoreFunction) { - scratch.clear(); for (NodeAtLevel n : inProgress) { if (n.node == newNode || n.level < layer) { continue; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java index 137781db..c2baa8c6 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java @@ -28,17 +28,16 @@ import io.github.jbellis.jvector.util.ArrayUtil; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.RamUsageEstimator; -import org.agrona.collections.IntHashSet; import java.util.Arrays; +import java.util.stream.IntStream; import static java.lang.Math.min; /** - * NodeArray encodes nodeids and their scores relative to some other element + * NodeArray encodes node IDs and their scores relative to some other element * (a query vector, or another graph node) as a pair of growable arrays. - * Nodes are arranged in the sorted order of their scores in descending order, - * i.e. the most-similar nodes are first. + * Nodes are arranged in ascending order using the nodeIDs. */ public class NodeArray { public static final NodeArray EMPTY = new NodeArray(0); @@ -50,6 +49,7 @@ public class NodeArray { public NodeArray(int initialSize) { nodes = new int[initialSize]; scores = new float[initialSize]; + size = 0; } // this idiosyncratic constructor exists for the benefit of subclass ConcurrentNeighborMap @@ -59,74 +59,46 @@ protected NodeArray(NodeArray nodeArray) { this.scores = nodeArray.scores; } - /** always creates a new NodeArray to return, even when a1 or a2 is empty */ + /** always creates a new NodeArray to return, even when a1 or a2 is empty. + * If a node ID is present in both, the one from a1 will be added. */ static NodeArray merge(NodeArray a1, NodeArray a2) { NodeArray merged = new NodeArray(a1.size() + a2.size()); - int i = 0, j = 0; - - // To avoid duplicates, we need to remember all the nodes added so far - var mergedNodes = new IntHashSet(); - float lastAddedScore = Float.NaN; + int i1 = 0, i2 = 0; // loop through both source arrays, adding the highest score element to the merged array, // until we reach the end of one of the sources - while (i < a1.size() && j < a2.size()) { - if (a1.scores[i] < a2.scores[j]) { - // add from a2 - if (a2.scores[j] != lastAddedScore) { - lastAddedScore = a2.scores[j]; - } - if (mergedNodes.add(a2.nodes[j])) { - merged.addInOrder(a2.nodes[j], a2.scores[j]); - } - j++; - } else if (a1.scores[i] > a2.scores[j]) { + while (i1 < a1.size() && i2 < a2.size()) { + if (a1.nodes[i1] < a2.nodes[i2]) { // add from a1 - if (a1.scores[i] != lastAddedScore) { - lastAddedScore = a1.scores[i]; - } - if (mergedNodes.add(a1.nodes[i])) { - merged.addInOrder(a1.nodes[i], a1.scores[i]); - } - i++; + merged.addInOrder(a1.nodes[i1], a1.scores[i1]); + i1++; + } else if (a1.nodes[i1] > a2.nodes[i2]) { + // add from a2 + merged.addInOrder(a2.nodes[i2], a2.scores[i2]); + i2++; } else { - // same score -- add both - if (a1.scores[i] != lastAddedScore) { - lastAddedScore = a1.scores[i]; - } - if (mergedNodes.add(a1.nodes[i])) { - merged.addInOrder(a1.nodes[i], a1.scores[i]); - } - if (mergedNodes.add(a2.nodes[j])) { - merged.addInOrder(a2.nodes[j], a2.scores[j]); - } - i++; - j++; + // same node -- add from a1 + merged.addInOrder(a1.nodes[i1], a1.scores[i1]); + i1++; + i2++; } } // If elements remain in a1, add them - if (i < a1.size()) { - // avoid duplicates while adding nodes with the same score - for (; i < a1.size; i++) { - if (mergedNodes.add(a1.nodes[i])) { - merged.addInOrder(a1.nodes[i], a1.scores[i]); - } + if (i1 < a1.size()) { + for (; i1 < a1.size; i1++) { + merged.addInOrder(a1.nodes[i1], a1.scores[i1]); } - merged.size += a1.size - i; + merged.size += a1.size - i1; } // If elements remain in a2, add them - if (j < a2.size()) { - // avoid duplicates while adding nodes with the same score - for (; j < a2.size; j++) { - if (mergedNodes.add(a2.nodes[j])) { - merged.addInOrder(a2.nodes[j], a2.scores[j]); - } + if (i2 < a2.size()) { + for (; i2 < a2.size; i2++) { + merged.addInOrder(a2.nodes[i2], a2.scores[i2]); } - merged.size += a2.size - j; + merged.size += a2.size - i2; } - return merged; } @@ -139,12 +111,12 @@ public void addInOrder(int newNode, float newScore) { growArrays(); } if (size > 0) { - float previousScore = scores[size - 1]; - assert ((previousScore >= newScore)) + int previousNode = nodes[size - 1]; + assert ((previousNode <= newNode)) : "Nodes are added in the incorrect order! Comparing " - + newScore + + newNode + " to " - + Arrays.toString(ArrayUtil.copyOfSubArray(scores, 0, size)); + + Arrays.toString(ArrayUtil.copyOfSubArray(nodes, 0, size)); } nodes[size] = newNode; scores[size] = newScore; @@ -153,10 +125,10 @@ public void addInOrder(int newNode, float newScore) { /** * Returns the index at which the given node should be inserted to maintain sorted order, - * or -1 if the node already exists in the array (with the same score). + * or -1 if the node already exists in the array. */ - int insertionPoint(int newNode, float newScore) { - int insertionPoint = descSortFindRightMostInsertionPoint(newScore); + int insertionPoint(int newNode) { + int insertionPoint = incSortFindRightMostInsertionPoint(newNode); return duplicateExists(insertionPoint, newNode) ? -1 : insertionPoint; } @@ -170,7 +142,7 @@ public int insertSorted(int newNode, float newScore) { if (size == nodes.length) { growArrays(); } - int insertionPoint = insertionPoint(newNode, newScore); + int insertionPoint = insertionPoint(newNode); if (insertionPoint == -1) { return -1; } @@ -197,21 +169,10 @@ private int insertInternal(int insertionPoint, int newNode, float newScore) { return insertionPoint; } - private boolean duplicateExists(int insertionPoint, int newNode) { - // Checking close to the insertion point first should be better that doing a scan from 0 to size - for (int i = 0; i < size + 1; i++) { - if (insertionPoint >= i && insertionPoint - i < size) { - if (nodes[insertionPoint - i] == newNode) { - return true; - } - } - if (insertionPoint + i < size) { - if(nodes[insertionPoint + i] == newNode) { - return true; - } - } - } - return false; + private boolean duplicateExists(int insertionPoint, int node) { + if (insertionPoint < size && nodes[insertionPoint] == node) return true; + if (insertionPoint + 1 < size && nodes[insertionPoint + 1] == node) return true; + return insertionPoint - 1 >= 0 && nodes[insertionPoint - 1] == node; } /** @@ -292,12 +253,12 @@ public String toString() { return sb.toString(); } - protected final int descSortFindRightMostInsertionPoint(float newScore) { + protected final int incSortFindRightMostInsertionPoint(int newNode) { int start = 0; int end = size - 1; while (start <= end) { int mid = (start + end) / 2; - if (scores[mid] < newScore) end = mid - 1; + if (nodes[mid] > newNode) end = mid - 1; else start = mid + 1; } return start; @@ -315,17 +276,10 @@ public static long ramBytesUsed(int size) { + (long) size * (Integer.BYTES + Float.BYTES); // array contents } - /** - * Caution! This performs a linear scan. - */ @VisibleForTesting boolean contains(int node) { - for (int i = 0; i < size; i++) { - if (this.nodes[i] == node) { - return true; - } - } - return false; + int insertionPoint = incSortFindRightMostInsertionPoint(node); + return duplicateExists(insertionPoint, node); } @VisibleForTesting @@ -358,4 +312,37 @@ public int getNode(int i) { protected int getArrayLength() { return nodes.length; } + + public NodesIterator getIteratorSortedByScores() { + return new ScoreSortedNeighborIterator(this); + } + + private static class ScoreSortedNeighborIterator implements NodesIterator { + private final NodeArray array; + private final int[] sortedIndices; + private int i; + + private ScoreSortedNeighborIterator(NodeArray array) { + this.array = array; + sortedIndices = IntStream.range(0, this.array.size()) + .boxed().sorted((i, j) -> Float.compare(this.array.getScore(j), this.array.getScore(i))) + .mapToInt(ele -> ele).toArray(); + i = 0; + } + + @Override + public int size() { + return array.size(); + } + + @Override + public boolean hasNext() { + return i < array.size(); + } + + @Override + public int nextInt() { + return sortedIndices[i++]; + } + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/DiversityProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/DiversityProvider.java index 7551aec7..af9baeaf 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/DiversityProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/DiversityProvider.java @@ -29,5 +29,5 @@ public interface DiversityProvider { * update `selected` with the diverse members of `neighbors`. `neighbors` is not modified * @return the fraction of short edges (neighbors within alpha=1.0) */ - double retainDiverse(NodeArray neighbors, int maxDegree, int diverseBefore, BitSet selected); + double retainDiverse(NodeArray neighbors, int maxDegree, BitSet selected); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java index 0bdc6415..5406abcc 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java @@ -17,6 +17,7 @@ package io.github.jbellis.jvector.graph.diversity; import io.github.jbellis.jvector.graph.NodeArray; +import io.github.jbellis.jvector.graph.NodesIterator; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.util.BitSet; @@ -42,18 +43,17 @@ public VamanaDiversityProvider(BuildScoreProvider scoreProvider, float alpha) { * It assumes that the i-th neighbor with 0 {@literal <=} i {@literal <} diverseBefore is already diverse. * @return the fraction of short edges (neighbors within alpha=1.0) */ - public double retainDiverse(NodeArray neighbors, int maxDegree, int diverseBefore, BitSet selected) { - for (int i = 0; i < min(diverseBefore, maxDegree); i++) { - selected.set(i); - } - - int nSelected = diverseBefore; + public double retainDiverse(NodeArray neighbors, int maxDegree, BitSet selected) { + int nSelected = 0; double shortEdges = Double.NaN; // add diverse candidates, gradually increasing alpha to the threshold // (so that the nearest candidates are prioritized) float currentAlpha = 1.0f; while (currentAlpha <= alpha + 1E-6 && nSelected < maxDegree) { - for (int i = diverseBefore; i < neighbors.size() && nSelected < maxDegree; i++) { + NodesIterator it = neighbors.getIteratorSortedByScores(); + while (it.hasNext() && nSelected < maxDegree) { + int i = it.nextInt(); + if (selected.get(i)) { continue; } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java index 4c82a652..1ec9767d 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java @@ -102,10 +102,11 @@ public void testRescore(boolean addHierarchy) { // Check node 0's neighbors, score and order should be different var newNeighbors = newGraph.getNeighbors(0, 0); - assertEquals(2, newNeighbors.getNode(0)); - assertEquals(1, newNeighbors.getNode(1)); - assertEquals(0.2f, newNeighbors.getScore(0), 1E-6); - assertEquals(0.05882353f, newNeighbors.getScore(1), 1E-6); + assertEquals(1, newNeighbors.getNode(0)); + assertEquals(2, newNeighbors.getNode(1)); + assertEquals(0.05882353f, newNeighbors.getScore(0), 1E-6); + assertEquals(0.2f, newNeighbors.getScore(1), 1E-6); + } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNeighbors.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNeighbors.java index f5d68702..2f70df47 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNeighbors.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNeighbors.java @@ -24,7 +24,7 @@ import java.util.stream.IntStream; -import static io.github.jbellis.jvector.graph.TestNodeArray.validateSortedByScore; +import static io.github.jbellis.jvector.graph.TestNodeArray.validateSortedByNode; import static org.junit.Assert.assertEquals; public class TestNeighbors extends RandomizedTest { @@ -49,7 +49,7 @@ public void testInsertDiverse() { assertEquals(2, neighbors.size()); assert neighbors.contains(8); assert neighbors.contains(6); - validateSortedByScore(neighbors); + validateSortedByNode(neighbors); } private static float scoreBetween(BuildScoreProvider bsp, int i, int j) { @@ -78,7 +78,7 @@ public void testInsertDiverseConcurrent() { assertEquals(2, neighbors.size()); assert neighbors.contains(8); assert neighbors.contains(6); - validateSortedByScore(neighbors); + validateSortedByNode(neighbors); } @Test diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java index 0f358620..e9aae751 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java @@ -25,12 +25,10 @@ package io.github.jbellis.jvector.graph; import com.carrotsearch.randomizedtesting.RandomizedTest; -import io.github.jbellis.jvector.util.ArrayUtil; import io.github.jbellis.jvector.util.FixedBitSet; import org.junit.Assert; import org.junit.Test; -import java.util.Arrays; import java.util.HashSet; import static org.junit.Assert.assertArrayEquals; @@ -39,9 +37,9 @@ import static org.junit.jupiter.api.Assertions.assertThrows; public class TestNodeArray extends RandomizedTest { - static void validateSortedByScore(NodeArray na) { + static void validateSortedByNode(NodeArray na) { for (int i = 0; i < na.size() - 1; i++) { - assertTrue(na.getScore(i) >= na.getScore(i + 1)); + assertTrue(na.getNode(i) < na.getNode(i + 1)); } } @@ -51,53 +49,56 @@ public void testScoresDescOrder() { neighbors.addInOrder(0, 1); neighbors.addInOrder(1, 0.8f); - AssertionError ex = assertThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f)); + AssertionError ex = assertThrows(AssertionError.class, () -> neighbors.addInOrder(-1, 0.9f)); assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage(); neighbors.insertSorted(3, 0.9f); assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 1}, neighbors); + assertNodesEqual(new int[] {0, 1, 3}, neighbors); neighbors.insertSorted(4, 1f); assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 4, 3, 1}, neighbors); + assertNodesEqual(new int[] {0, 1, 3, 4}, neighbors); neighbors.insertSorted(5, 1.1f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors); + assertNodesEqual(new int[] {0, 1, 3, 4, 5}, neighbors); neighbors.insertSorted(6, 0.8f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors); + assertNodesEqual(new int[] {0, 1, 3, 4, 5, 6}, neighbors); neighbors.insertSorted(7, 0.8f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {0, 1, 3, 4, 5, 6, 7}, neighbors); neighbors.removeIndex(2); - assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors); + assertScoresEqual(new float[] {1.1f, 1, 1, 0.8f, 0.8f, 0.8f}, neighbors); + assertNodesEqual(new int[] {0, 1, 4, 5, 6, 7}, neighbors); neighbors.removeIndex(0); - assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors); + NodesIterator it = neighbors.getIteratorSortedByScores(); + assertScoresEqual(new float[] {1.1f, 1, 0.8f, 0.8f, 0.8f}, neighbors); + assertNodesEqual(new int[] {1, 4, 5, 6, 7}, neighbors); neighbors.removeIndex(4); - assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors); + assertScoresEqual(new float[] {1.1f, 1, 0.8f, 0.8f}, neighbors); + assertNodesEqual(new int[] {1, 4, 5, 6}, neighbors); neighbors.removeLast(); - assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 1}, neighbors); + assertScoresEqual(new float[] {1.1f, 1, 0.8f}, neighbors); + assertNodesEqual(new int[] {1, 4, 5}, neighbors); neighbors.insertSorted(8, 0.9f); - assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors); + assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f}, neighbors); + assertNodesEqual(new int[] {1, 4, 5, 8}, neighbors); } private void assertScoresEqual(float[] scores, NodeArray neighbors) { - for (int i = 0; i < scores.length; i++) { - assertEquals(scores[i], neighbors.getScore(i), 0.01f); + NodesIterator it = neighbors.getIteratorSortedByScores(); + int i = 0; + while (it.hasNext()) { + assertEquals(scores[i++], neighbors.getScore(it.nextInt()), 0.01f); } } @@ -170,7 +171,7 @@ public void testNoDuplicatesDescOrder() { cna.insertSorted(3, 8.0f); // This is also a duplicate Assert.assertArrayEquals(new int[] {1, 2, 3}, cna.copyDenseNodes()); assertArrayEquals(new float[] {10.0f, 9.0f, 8.0f}, cna.copyDenseScores(), 0.01f); - validateSortedByScore(cna); + validateSortedByNode(cna); } @Test @@ -183,7 +184,7 @@ public void testNoDuplicatesSameScores() { cna.insertSorted(3, 10.0f); // This is also a duplicate assertArrayEquals(new int[] {1, 2, 3}, cna.copyDenseNodes()); assertArrayEquals(new float[] {10.0f, 10.0f, 10.0f}, cna.copyDenseScores(), 0.01f); - validateSortedByScore(cna); + validateSortedByNode(cna); } @Test @@ -199,33 +200,32 @@ public void testMergeCandidatesSimple() { assertArrayEquals(new int[] {0, 1}, merged.copyDenseNodes()); arr1 = new NodeArray(3); - arr1.addInOrder(3, 3.0f); - arr1.addInOrder(2, 2.0f); arr1.addInOrder(1, 1.0f); + arr1.addInOrder(2, 2.0f); + arr1.addInOrder(3, 3.0f); arr2 = new NodeArray(3); - arr2.addInOrder(4, 4.0f); + arr2.addInOrder(1, 1.05f); arr2.addInOrder(2, 2.0f); - arr2.addInOrder(1, 1.0f); + arr2.addInOrder(4, 4.0f); merged = NodeArray.merge(arr1, arr2); - // Expected result: [4, 3, 2, 1] - assertArrayEquals(new int[] {4, 3, 2, 1}, merged.copyDenseNodes()); - assertArrayEquals(new float[] {4.0f, 3.0f, 2.0f, 1.0f}, merged.copyDenseScores(), 0.0f); + assertArrayEquals(new int[] {1, 2, 3, 4}, merged.copyDenseNodes()); + assertScoresEqual(new float[] {4.0f, 3.0f, 2.0f, 1.0f}, merged); // Testing boundary conditions arr1 = new NodeArray(2); - arr1.addInOrder(3, 3.0f); arr1.addInOrder(2, 2.0f); + arr1.addInOrder(3, 3.0f); arr2 = new NodeArray(1); arr2.addInOrder(2, 2.0f); merged = NodeArray.merge(arr1, arr2); // Expected result: [3, 2] - assertArrayEquals(new int[] {3, 2}, merged.copyDenseNodes()); - assertArrayEquals(new float[] {3.0f, 2.0f}, merged.copyDenseScores(), 0.0f); - validateSortedByScore(merged); + assertArrayEquals(new int[] {2, 3}, merged.copyDenseNodes()); + assertScoresEqual(new float[] {3.0f, 2.0f}, merged); + validateSortedByNode(merged); } private void testMergeCandidatesOnce() { @@ -274,7 +274,7 @@ private void testMergeCandidatesOnce() { // results should be sorted by score, and not contain duplicates for (int i = 0; i < merged.size() - 1; i++) { - assertTrue(merged.getScore(i) >= merged.getScore(i + 1)); + assertTrue(merged.getNode(i) < merged.getNode(i + 1)); assertTrue(uniqueNodes.add(merged.getNode(i))); } assertTrue(uniqueNodes.add(merged.getNode(merged.size() - 1))); From 76f37033fd30c28ddab207ed96e51a8f7b78555d Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Fri, 19 Sep 2025 08:17:32 -0700 Subject: [PATCH 5/5] Since we removed one integer from the adjacency list, we need to adjust the expected size down by 4 in GraphIndexBuilderTest.testEstimatedBytes --- .../io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java index 840e5877..4468e129 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java @@ -68,7 +68,7 @@ public void testEstimatedBytes() throws IOException { try (var builder = new GraphIndexBuilder(bsp, 2, 2, 10, 1.0f, 1.0f, false)) { var bytesUsed = builder.addGraphNode(0, ravv.getVector(0)); // The actual value is not critical, but this confirms we do not get unexpected changes (for this config) - assertEquals(92, bytesUsed); + assertEquals(88, bytesUsed); } }