diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java index cd0ef61a6..b71591f33 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java @@ -52,7 +52,7 @@ public class RecallWithRandomVectorsBenchmark { private ArrayList> baseVectors; private ArrayList> queryVectors; private GraphIndexBuilder graphIndexBuilder; - private GraphIndex graphIndex; + private ImmutableGraphIndex graphIndex; private PQVectors pqVectors; // Add ground truth storage diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java index f1477e7ea..a3651aabc 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java @@ -45,7 +45,7 @@ public class StaticSetVectorsBenchmark { private List> queryVectors; private List> groundTruth; private GraphIndexBuilder graphIndexBuilder; - private GraphIndex graphIndex; + private ImmutableGraphIndex graphIndex; int originalDimension; @Setup diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 6c9960ba4..2cff2de4a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -18,7 +18,7 @@ import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.disk.RandomAccessReader; -import 
io.github.jbellis.jvector.graph.GraphIndex.NodeAtLevel; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex.NodeAtLevel; import io.github.jbellis.jvector.graph.SearchResult.NodeScore; import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; @@ -30,7 +30,6 @@ import io.github.jbellis.jvector.util.PhysicalCoreExecutor; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; -import org.agrona.collections.IntArrayList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,7 +48,7 @@ import static java.lang.Math.*; /** - * Builder for Concurrent GraphIndex. See {@link GraphIndex} for a high level overview, and the + * Builder for Concurrent GraphIndex. See {@link ImmutableGraphIndex} for a high level overview, and the * comments to `addGraphNode` for details on the concurrent building approach. *

* GIB allocates scratch space and copies of the RandomAccessVectorValues for each thread @@ -71,7 +70,7 @@ public class GraphIndexBuilder implements Closeable { private final boolean refineFinalGraph; @VisibleForTesting - final OnHeapGraphIndex graph; + final MutableGraphIndex graph; private final ConcurrentSkipListSet insertionsInProgress = new ConcurrentSkipListSet<>(); @@ -343,7 +342,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) { var newBuilder = new GraphIndexBuilder(newProvider, other.dimension, - other.graph.maxDegrees, + other.graph.maxDegrees(), other.beamWidth, other.neighborOverflow, other.alpha, @@ -352,17 +351,13 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi other.simdExecutor, other.parallelExecutor); + var otherView = other.graph.getView(); + // Copy each node and its neighbors from the old graph to the new one other.parallelExecutor.submit(() -> { IntStream.range(0, other.graph.getIdUpperBound()).parallel().forEach(i -> { // Find the highest layer this node exists in - int maxLayer = -1; - for (int lvl = 0; lvl < other.graph.layers.size(); lvl++) { - if (other.graph.getNeighbors(lvl, i) == null) { - break; - } - maxLayer = lvl; - } + int maxLayer = other.graph.getMaxLevelForNode(i); if (maxLayer < 0) { return; } @@ -370,7 +365,7 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi // Loop over 0..maxLayer, re-score neighbors for each layer var sf = newProvider.searchProviderFor(i).scoreFunction(); for (int lvl = 0; lvl <= maxLayer; lvl++) { - var oldNeighborsIt = other.graph.getNeighborsIterator(lvl, i); + var oldNeighborsIt = otherView.getNeighborsIterator(lvl, i); // Copy edges, compute new scores var newNeighbors = new NodeArray(oldNeighborsIt.size()); while (oldNeighborsIt.hasNext()) { @@ -378,18 +373,18 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder 
other, BuildScoreProvi // since we're using a different score provider, use insertSorted instead of addInOrder newNeighbors.insertSorted(neighbor, sf.similarityTo(neighbor)); } - newBuilder.graph.addNode(lvl, i, newNeighbors); + newBuilder.graph.connectNode(lvl, i, newNeighbors); } }); }).join(); // Set the entry node - newBuilder.graph.updateEntryNode(other.graph.entry()); + newBuilder.graph.updateEntryNode(otherView.entryNode()); return newBuilder; } - public OnHeapGraphIndex build(RandomAccessVectorValues ravv) { + public ImmutableGraphIndex build(RandomAccessVectorValues ravv) { var vv = ravv.threadLocalSupplier(); int size = ravv.size(); @@ -403,6 +398,19 @@ public OnHeapGraphIndex build(RandomAccessVectorValues ravv) { return graph; } + /** + * Validates that the current entry node has been completely added. + */ + void validateEntryNode() { + if (graph.size(0) == 0) { + return; + } + NodeAtLevel entry = graph.entryNode(); + if (entry == null || !graph.getView().contains(entry.level, entry.node)) { + throw new IllegalStateException("Entry node was incompletely added! " + entry); + } + } + /** * Cleanup the graph by completing removal of marked-for-delete nodes, trimming * neighbor sets to the advertised degree, and updating the entry node. @@ -417,7 +425,7 @@ public void cleanup() { if (graph.size(0) == 0) { return; } - graph.validateEntryNode(); // sanity check before we start + validateEntryNode(); // sanity check before we start // purge deleted nodes. 
// backlinks can cause neighbors to soft-overflow, so do this before neighbors cleanup @@ -442,34 +450,35 @@ public void cleanup() { // clean up overflowed neighbor lists parallelExecutor.submit(() -> { IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(id -> { - for (int layer = 0; layer < graph.layers.size(); layer++) { - graph.layers.get(layer).enforceDegree(id); + for (int layer = 0; layer <= graph.getMaxLevel(); layer++) { + graph.enforceDegree(id); } }); }).join(); + + graph.allMutationsCompleted(); } private void improveConnections(int node) { var ssp = scoreProvider.searchProviderFor(node); var bits = new ExcludingBits(node); try (var gs = searchers.get()) { - gs.initializeInternal(ssp, graph.entry(), bits); + gs.initializeInternal(ssp, graph.entryNode(), bits); var acceptedBits = Bits.intersectionOf(bits, gs.getView().liveNodes()); // Move downward from entry.level to 0 - for (int lvl = graph.entry().level; lvl >= 0; lvl--) { + for (int lvl = graph.entryNode().level; lvl >= 0; lvl--) { // This additional call seems redundant given that we have already initialized an ssp above. // However, there is a subtle interplay between the ssp of the search and the ssp used in insertDiverse. // Do not remove this line. 
ssp = scoreProvider.searchProviderFor(node); - if (graph.layers.get(lvl).get(node) != null) { + if (graph.getNeighborsIterator(lvl, node).size() > 0) { gs.searchOneLayer(ssp, beamWidth, 0.0f, lvl, acceptedBits); var candidates = new NodeArray(gs.approximateResults.size()); gs.approximateResults.foreach(candidates::insertSorted); - var newNeighbors = graph.layers.get(lvl).insertDiverse(node, candidates); - graph.layers.get(lvl).backlink(newNeighbors, node, neighborOverflow); + graph.addEdges(lvl, node, candidates, neighborOverflow); } else { gs.searchOneLayer(ssp, 1, 0.0f, lvl, acceptedBits); } @@ -480,7 +489,7 @@ private void improveConnections(int node) { } } - public OnHeapGraphIndex getGraph() { + public ImmutableGraphIndex getGraph() { return graph; } @@ -554,12 +563,13 @@ public long addGraphNode(int node, SearchScoreProvider searchScoreProvider) { insertionsInProgress.add(nodeLevel); var inProgressBefore = insertionsInProgress.clone(); try (var gs = searchers.get()) { - gs.setView(graph.getView()); // new snapshot + var view = graph.getView(); + gs.setView(view); // new snapshot var naturalScratchPooled = naturalScratch.get(); var concurrentScratchPooled = concurrentScratch.get(); var bits = new ExcludingBits(nodeLevel.node); - var entry = graph.entry(); + var entry = view.entryNode(); SearchResult result; if (entry == null) { result = new SearchResult(new NodeScore[] {}, 0, 0, 0, 0, 0); @@ -600,7 +610,7 @@ public long addGraphNode(int node, SearchScoreProvider searchScoreProvider) { return IntStream.rangeClosed(0, nodeLevel.level).mapToLong(graph::ramBytesUsedOneNode).sum(); } - private void updateNeighborsOneLayer(int layer, int node, NodeScore[] neighbors, NodeArray naturalScratchPooled, ConcurrentSkipListSet inProgressBefore, NodeArray concurrentScratchPooled, SearchScoreProvider ssp) { + private void updateNeighborsOneLayer(int level, int node, NodeScore[] neighbors, NodeArray naturalScratchPooled, ConcurrentSkipListSet inProgressBefore, NodeArray 
concurrentScratchPooled, SearchScoreProvider ssp) { // Update neighbors with these candidates. // The DiskANN paper calls for using the entire set of visited nodes along the search path as // potential candidates, but in practice we observe neighbor lists being completely filled using @@ -608,8 +618,8 @@ private void updateNeighborsOneLayer(int layer, int node, NodeScore[] neighbors, // this means that considering additional nodes from the search path, that are by definition // farther away than the ones in the topK, would not change the result.) var natural = toScratchCandidates(neighbors, naturalScratchPooled); - var concurrent = getConcurrentCandidates(layer, node, inProgressBefore, concurrentScratchPooled, ssp.scoreFunction()); - updateNeighbors(layer, node, natural, concurrent); + var concurrent = getConcurrentCandidates(level, node, inProgressBefore, concurrentScratchPooled, ssp.scoreFunction()); + updateNeighbors(level, node, natural, concurrent); } @VisibleForTesting @@ -636,7 +646,7 @@ public synchronized long removeDeletedNodes() { return 0; } - for (int currentLevel = 0; currentLevel < graph.layers.size(); currentLevel++) { + for (int currentLevel = 0; currentLevel <= graph.getMaxLevel(); currentLevel++) { final int level = currentLevel; // Create effectively final copy for lambda // Compute new edges to insert. If node j is deleted, we add edges (i, k) // whenever (i, j) and (j, k) are directed edges in the current graph. 
This @@ -687,10 +697,10 @@ public synchronized long removeDeletedNodes() { // doing actual sampling-without-replacement is expensive so we'll loop a fixed number of times instead for (int i = 0; i < 2 * graph.getDegree(level); i++) { int randomNode = R.nextInt(graph.getIdUpperBound()); - while(toDelete.get(randomNode)) { + while (toDelete.get(randomNode)) { randomNode = R.nextInt(graph.getIdUpperBound()); } - if (randomNode != node && !candidates.contains(randomNode) && graph.layers.get(level).contains(randomNode)) { + if (randomNode != node && !candidates.contains(randomNode) && graph.contains(level, randomNode)) { float score = sf.similarityTo(randomNode); candidates.insertSorted(randomNode, score); } @@ -701,14 +711,14 @@ public synchronized long removeDeletedNodes() { } // remove edges to deleted nodes and add the new connections, maintaining diversity - graph.layers.get(level).replaceDeletedNeighbors(node, toDelete, candidates); + graph.replaceDeletedNeighbors(level, node, toDelete, candidates); }); }).join(); } // Generally we want to keep entryPoint update and node removal distinct, because both can be expensive, // but if the entry point was deleted then we have no choice - if (toDelete.get(graph.entry().node)) { + if (toDelete.get(graph.entryNode().node)) { // pick a random node at the top layer int newLevel = graph.getMaxLevel(); int newEntry = -1; @@ -740,7 +750,7 @@ public synchronized long removeDeletedNodes() { return memorySize; } - private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) { + private void updateNeighbors(int level, int nodeId, NodeArray natural, NodeArray concurrent) { // if either natural or concurrent is empty, skip the merge NodeArray toMerge; if (concurrent.size() == 0) { @@ -751,8 +761,7 @@ private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray toMerge = NodeArray.merge(natural, concurrent); } // toMerge may be approximate-scored, but insertDiverse will compute exact 
scores for the diverse ones - var neighbors = graph.layers.get(layer).insertDiverse(nodeId, toMerge); - graph.layers.get(layer).backlink(neighbors, nodeId, neighborOverflow); + graph.addEdges(level, nodeId, toMerge, neighborOverflow); } private static NodeArray toScratchCandidates(NodeScore[] candidates, NodeArray scratch) { @@ -763,7 +772,7 @@ private static NodeArray toScratchCandidates(NodeScore[] candidates, NodeArray s return scratch; } - private NodeArray getConcurrentCandidates(int layer, + private NodeArray getConcurrentCandidates(int level, int newNode, Set inProgress, NodeArray scratch, @@ -771,7 +780,7 @@ private NodeArray getConcurrentCandidates(int layer, { scratch.clear(); for (NodeAtLevel n : inProgress) { - if (n.node == newNode || n.level < layer) { + if (n.node == newNode || n.level < level) { continue; } scratch.insertSorted(n.node, scoreFunction.similarityTo(n.node)); @@ -801,6 +810,7 @@ public boolean get(int index) { } } + @Deprecated public void load(RandomAccessReader in) throws IOException { if (graph.size(0) != 0) { throw new IllegalStateException("Cannot load into a non-empty graph"); @@ -819,6 +829,7 @@ public void load(RandomAccessReader in) throws IOException { } } + @Deprecated private void loadV4(RandomAccessReader in) throws IOException { if (graph.size(0) != 0) { throw new IllegalStateException("Cannot load into a non-empty graph"); @@ -851,7 +862,7 @@ private void loadV4(RandomAccessReader in) throws IOException { int neighbor = in.readInt(); ca.addInOrder(neighbor, sf.similarityTo(neighbor)); } - graph.addNode(level, nodeId, ca); + graph.connectNode(level, nodeId, ca); nodeLevelMap.put(nodeId, level); } } @@ -865,7 +876,7 @@ private void loadV4(RandomAccessReader in) throws IOException { graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode)); } - + @Deprecated private void loadV3(RandomAccessReader in, int size) throws IOException { if (graph.size() != 0) { throw new IllegalStateException("Cannot load into a 
non-empty graph"); @@ -891,7 +902,7 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { int neighbor = in.readInt(); ca.addInOrder(neighbor, sf.similarityTo(neighbor)); } - graph.addNode(0, nodeId, ca); + graph.connectNode(0, nodeId, ca); graph.markComplete(new NodeAtLevel(0, nodeId)); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java index 1674ffda4..60a91c9ae 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java @@ -25,7 +25,7 @@ package io.github.jbellis.jvector.graph; import io.github.jbellis.jvector.annotations.Experimental; -import io.github.jbellis.jvector.graph.GraphIndex.NodeAtLevel; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex.NodeAtLevel; import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; @@ -43,10 +43,10 @@ /** * Searches a graph to find nearest neighbors to a query vector. For more background on the - * search algorithm, see {@link GraphIndex}. + * search algorithm, see {@link ImmutableGraphIndex}. */ public class GraphSearcher implements Closeable { - private GraphIndex.View view; + private ImmutableGraphIndex.View view; // Scratch data structures that are used in each {@link #searchInternal} call. These can be expensive // to allocate, so they're cleared and reused across calls. 
@@ -71,14 +71,14 @@ public class GraphSearcher implements Closeable { /** * Creates a new graph searcher from the given GraphIndex */ - public GraphSearcher(GraphIndex graph) { + public GraphSearcher(ImmutableGraphIndex graph) { this(graph.getView()); } /** * Creates a new graph searcher from the given GraphIndex.View */ - protected GraphSearcher(GraphIndex.View view) { + protected GraphSearcher(ImmutableGraphIndex.View view) { this.view = view; this.candidates = new NodeQueue(new GrowableLongHeap(100), NodeQueue.Order.MAX_HEAP); this.evictedResults = new NodesUnsorted(100); @@ -112,7 +112,7 @@ private void initializeScoreProvider(SearchScoreProvider scoreProvider) { cachingReranker = new CachingReranker(scoreProvider); } - public GraphIndex.View getView() { + public ImmutableGraphIndex.View getView() { return view; } @@ -129,7 +129,7 @@ public void usePruning(boolean usage) { * Convenience function for simple one-off searches. It is caller's responsibility to make sure that it * is the unique owner of the vectors instance passed in here. */ - public static SearchResult search(VectorFloat queryVector, int topK, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, GraphIndex graph, Bits acceptOrds) { + public static SearchResult search(VectorFloat queryVector, int topK, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, ImmutableGraphIndex graph, Bits acceptOrds) { try (var searcher = new GraphSearcher(graph)) { var ssp = DefaultSearchScoreProvider.exact(queryVector, similarityFunction, vectors); return searcher.search(ssp, topK, acceptOrds); @@ -142,7 +142,7 @@ public static SearchResult search(VectorFloat queryVector, int topK, RandomAc * Convenience function for simple one-off searches. It is caller's responsibility to make sure that it * is the unique owner of the vectors instance passed in here. 
*/ - public static SearchResult search(VectorFloat queryVector, int topK, int rerankK, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, GraphIndex graph, Bits acceptOrds) { + public static SearchResult search(VectorFloat queryVector, int topK, int rerankK, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, ImmutableGraphIndex graph, Bits acceptOrds) { try (var searcher = new GraphSearcher(graph)) { var ssp = DefaultSearchScoreProvider.exact(queryVector, similarityFunction, vectors); return searcher.search(ssp, topK, rerankK, 0.f, 0.f, acceptOrds); @@ -160,7 +160,7 @@ public static SearchResult search(VectorFloat queryVector, int topK, int rera * * @param view the new view */ - public void setView(GraphIndex.View view) { + public void setView(ImmutableGraphIndex.View view) { this.view = view; } @@ -169,9 +169,9 @@ public void setView(GraphIndex.View view) { */ @Deprecated public static class Builder { - private final GraphIndex.View view; + private final ImmutableGraphIndex.View view; - public Builder(GraphIndex.View view) { + public Builder(ImmutableGraphIndex.View view) { this.view = view; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java similarity index 92% rename from jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java index 602bf71f1..088f9a1af 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java @@ -29,6 +29,8 @@ import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; + +import java.util.List; import java.util.Objects; import 
java.io.Closeable; @@ -44,7 +46,7 @@ * All methods are threadsafe. Operations that require persistent state are wrapped * in a View that should be created per accessing thread. */ -public interface GraphIndex extends AutoCloseable, Accountable { +public interface ImmutableGraphIndex extends AutoCloseable, Accountable { /** Returns the number of nodes in the graph */ @Deprecated default int size() { @@ -76,6 +78,8 @@ default int size() { */ int maxDegree(); + List maxDegrees(); + /** * @return the first ordinal greater than all node ids in the graph. Equal to size() in simple cases; * May be different from size() if nodes are being added concurrently, or if nodes have been @@ -107,6 +111,14 @@ default boolean containsNode(int nodeId) { */ int getDegree(int level); + /** + * Returns the average degree computed over nodes in the specified layer. + * + * @param level the level of interest. + * @return the average degree or NaN if no nodes are present. + */ + double getAverageDegree(int level); + /** * Return the number of vectors/nodes in the given level. * @param level The level of interest @@ -150,6 +162,11 @@ interface View extends Closeable { default int getIdUpperBound() { return size(); } + + /** + * Whether the given node is present in the given layer of the graph. 
+ */ + boolean contains(int level, int node); } /** @@ -161,7 +178,7 @@ interface ScoringView extends View { ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat queryVector, VectorSimilarityFunction vsf); } - static String prettyPrint(GraphIndex graph) { + static String prettyPrint(ImmutableGraphIndex graph) { StringBuilder sb = new StringBuilder(); sb.append(graph); sb.append("\n"); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java new file mode 100644 index 000000000..2e88e6dd4 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java @@ -0,0 +1,170 @@ +/* + * All changes to the original code are Copyright DataStax, Inc. + * + * Please see the included license file for details. + */ + +/* + * Original license: + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.github.jbellis.jvector.graph; + +import io.github.jbellis.jvector.util.BitSet; +import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet; + +import java.util.List; +import java.util.stream.IntStream; + + +/** + * A mutable extension of {@link ImmutableGraphIndex} that offers concurrent access; for typical graphs you will get significant + * speedups in construction and searching as you add threads. + * + *

The base layer (layer 0) contains all nodes, while higher layers are stored in sparse maps. + * For searching, use a view obtained from {@link #getView()} which supports level–aware operations. + */ +interface MutableGraphIndex extends ImmutableGraphIndex { + /** + * Add the given node ordinal with an empty set of neighbors. + * + *

Nodes can be inserted out of order, but it requires that the nodes preceded by the node + * inserted out of order are eventually added. + * + *

Actually populating the neighbors, and establishing bidirectional links, is the + * responsibility of the caller. + * + *

It is also the responsibility of the caller to ensure that each node is only added once. + */ + void addNode(NodeAtLevel nodeLevel); + + /** + * Add the given node ordinal with an empty set of neighbors. + * + *

Nodes can be inserted out of order, but it requires that the nodes preceded by the node + * inserted out of order are eventually added. + * + *

Actually populating the neighbors, and establishing bidirectional links, is the + * responsibility of the caller. + * + *

It is also the responsibility of the caller to ensure that each node is only added once. + */ + void addNode(int level, int node); + + /** + * Whether the given node is present in the graph. + */ + boolean contains(NodeAtLevel nodeLevel); + + /** + * Whether the given node is present in the given layer of the graph. + */ + boolean contains(int level, int node); + + /** + * Add the given node ordinal at the given level with the given set of neighbors. + * + *

Nodes can be inserted out of order, but it requires that the nodes preceded by the node + * inserted out of order are eventually added. + * + *

Actually populating the neighbors, and establishing bidirectional links, is the + * responsibility of the caller. + * + *

It is also the responsibility of the caller to ensure that each node is only added once. + */ + void connectNode(int level, int node, NodeArray nodes); + + /** + * Use with extreme caution. Used by Builder to load a saved graph and for rescoring. + */ + void connectNode(NodeAtLevel nodeLevel, NodeArray nodes); + + /** + * Mark the given node deleted. Does NOT remove the node from the graph. + */ + void markDeleted(int node); + + /** must be called after addNode once neighbors are linked in all levels. */ + void markComplete(NodeAtLevel nodeLevel); + + void updateEntryNode(NodeAtLevel newEntry); + + /** + * Returns an upper bound on the amount of memory used by a single node, in bytes. + */ + long ramBytesUsedOneNode(int layer); + + ThreadSafeGrowableBitSet getDeletedNodes(); + + void setDegrees(List layerDegrees); + + /** + * Enforce the degree of the given node in all layers. + */ + void enforceDegree(int node); + + /** + * Returns an iterator over the neighbors for the given node at the specified level, which can be empty if the node does not belong to that layer. + */ + NodesIterator getNeighborsIterator(NodeAtLevel nodeLevel); + + /** + * Returns an iterator over the neighbors for the given node at the specified level, which can be empty if the node does not belong to that layer. + */ + NodesIterator getNeighborsIterator(int level, int node); + + /** + * Removes the given node from all layers. + * + * @param node the node id to remove + * @return the number of layers from which it was removed + */ + int removeNode(int node); + + /** + * Returns an Integer stream with the nodes contained in the specified level. + */ + IntStream nodeStream(int level); + + /** + * Returns the maximum (coarsest) level that contains the given node, or -1 if the node is not in the graph.
+ */ + int getMaxLevelForNode(int node); + + /** + * @return the node of the graph to start searches at + */ + NodeAtLevel entryNode(); + + /** + * Add the given neighbors to the given node at the specified level, maintaining diversity + * It also adds backlinks from the neighbors to the given node. + * The edges will only be added if the out-degree of the node is less than overflowRatio times the max degree. + */ + void addEdges(int level, int node, NodeArray candidates, float overflowRatio); + + /** + * Remove edges to deleted nodes and add the new connections, maintaining diversity + */ + void replaceDeletedNeighbors(int level, int node, BitSet toDelete, NodeArray candidates); + + /** + * Signals that all mutations have been completed and the graph will not be mutated any further. + * Should be called by the builder after all mutations are completed (during cleanup). + */ + void allMutationsCompleted(); +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 69f3527f4..7ddbf7897 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -26,14 +26,13 @@ import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors; import io.github.jbellis.jvector.graph.diversity.DiversityProvider; -import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.util.Accountable; +import io.github.jbellis.jvector.util.BitSet; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.DenseIntMap; import io.github.jbellis.jvector.util.RamUsageEstimator; import io.github.jbellis.jvector.util.SparseIntMap; import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet; -import io.github.jbellis.jvector.vector.types.VectorFloat; import 
org.agrona.collections.IntArrayList; import java.io.DataOutput; @@ -42,8 +41,6 @@ import java.util.ArrayList; import java.util.List; import java.util.NoSuchElementException; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.concurrent.atomic.AtomicReference; @@ -51,13 +48,13 @@ import java.util.stream.IntStream; /** - * An {@link GraphIndex} that offers concurrent access; for typical graphs you will get significant + * An {@link ImmutableGraphIndex} that offers concurrent access; for typical graphs you will get significant * speedups in construction and searching as you add threads. * *

The base layer (layer 0) contains all nodes, while higher layers are stored in sparse maps. * For searching, use a view obtained from {@link #getView()} which supports level–aware operations. */ -public class OnHeapGraphIndex implements GraphIndex { +public class OnHeapGraphIndex implements MutableGraphIndex { // Used for saving and loading OnHeapGraphIndex public static final int MAGIC = 0x75EC4012; // JVECTOR, with some imagination @@ -72,11 +69,13 @@ public class OnHeapGraphIndex implements GraphIndex { private final AtomicInteger maxNodeId = new AtomicInteger(-1); // Maximum number of neighbors (edges) per node per layer - final IntArrayList maxDegrees; + final List maxDegrees; // The ratio by which we can overflow the neighborhood of a node during construction. Since it is a multiplicative // ratio, i.e., the maximum allowable degree if maxDegree * overflowRatio, it should be higher than 1. private final double overflowRatio; + private volatile boolean allMutationsCompleted = false; + OnHeapGraphIndex(List maxDegrees, double overflowRatio, DiversityProvider diversityProvider) { this.overflowRatio = overflowRatio; this.maxDegrees = new IntArrayList(); @@ -106,14 +105,13 @@ Neighbors getNeighbors(int level, int node) { return layers.get(level).get(node); } - /** - * Returns an iterator over the neighbors for the given node at the specified level. 
- * - * @param level the layer - * @param node the node id - * @return a NodesIterator, which can be empty - */ - NodesIterator getNeighborsIterator(int level, int node) { + @Override + public NodesIterator getNeighborsIterator(NodeAtLevel nodeAtLevel) { + return getNeighborsIterator(nodeAtLevel.level, nodeAtLevel.node); + } + + @Override + public NodesIterator getNeighborsIterator(int level, int node) { if (level >= layers.size()) { return NodesIterator.EMPTY_NODE_ITERATOR; } @@ -125,30 +123,45 @@ NodesIterator getNeighborsIterator(int level, int node) { } } + @Override + public int getMaxLevelForNode(int node) { + int maxLayer = -1; + for (int lvl = 0; lvl < layers.size(); lvl++) { + if (getNeighbors(lvl, node) == null) { + break; + } + maxLayer = lvl; + } + return maxLayer; + } + @Override public int size(int level) { return layers.get(level).size(); } - /** - * Add the given node ordinal with an empty set of neighbors. - * - *

Nodes can be inserted out of order, but it requires that the nodes preceded by the node - * inserted out of order are eventually added. - * - *

Actually populating the neighbors, and establishing bidirectional links, is the - * responsibility of the caller. - * - *

It is also the responsibility of the caller to ensure that each node is only added once. - */ public void addNode(NodeAtLevel nodeLevel) { - ensureLayersExist(nodeLevel.level); + addNode(nodeLevel.level, nodeLevel.node); + } + + public void addNode(int level, int node) { + ensureLayersExist(level); // add the node to each layer - for (int i = 0; i <= nodeLevel.level; i++) { - layers.get(i).addNode(nodeLevel.node); + for (int i = 0; i <= level; i++) { + layers.get(i).addNode(node); } - maxNodeId.accumulateAndGet(nodeLevel.node, Math::max); + maxNodeId.accumulateAndGet(node, Math::max); + } + + @Override + public boolean contains(NodeAtLevel nodeLevel) { + return contains(nodeLevel.level, nodeLevel.node); + } + + @Override + public boolean contains(int level, int node) { + return layers.get(level).contains(node); } private void ensureLayersExist(int level) { @@ -166,14 +179,15 @@ private void ensureLayersExist(int level) { } } - /** - * Only for use by Builder loading a saved graph - */ - void addNode(int level, int nodeId, NodeArray nodes) { + public void connectNode(NodeAtLevel nodeLevel, NodeArray nodes) { + connectNode(nodeLevel.level, nodeLevel.node, nodes); + } + + public void connectNode(int level, int node, NodeArray nodes) { assert nodes != null; ensureLayersExist(level); - this.layers.get(level).addNode(nodeId, nodes); - maxNodeId.accumulateAndGet(nodeId, Math::max); + this.layers.get(level).addNode(node, nodes); + maxNodeId.accumulateAndGet(node, Math::max); } /** @@ -183,8 +197,7 @@ public void markDeleted(int node) { deletedNodes.set(node); } - /** must be called after addNode once neighbors are linked in all levels. 
*/ - void markComplete(NodeAtLevel nodeLevel) { + public void markComplete(NodeAtLevel nodeLevel) { entryPoint.accumulateAndGet( nodeLevel, (oldEntry, newEntry) -> { @@ -197,11 +210,12 @@ void markComplete(NodeAtLevel nodeLevel) { completions.markComplete(nodeLevel.node); } - void updateEntryNode(NodeAtLevel newEntry) { + public void updateEntryNode(NodeAtLevel newEntry) { entryPoint.set(newEntry); } - NodeAtLevel entry() { + @Override + public NodeAtLevel entryNode() { return entryPoint.get(); } @@ -211,11 +225,8 @@ public NodesIterator getNodes(int level) { layers.get(level).size()); } - /** - * this does call get() internally to filter level 0, so if you're going to use it in a pipeline - * that also calls get(), consider using your own raw IntStream.range instead - */ - IntStream nodeStream(int level) { + @Override + public IntStream nodeStream(int level) { var layer = layers.get(level); return level == 0 ? IntStream.range(0, getIdUpperBound()).filter(i -> layer.get(i) != null) @@ -228,59 +239,57 @@ public long ramBytesUsed() { return graphBytesUsed + completions.ramBytesUsed(); } - public long ramBytesUsedOneLayer(int layer) { + private long ramBytesUsedOneLayer(int level) { int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; var REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF; var AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; - long neighborSize = ramBytesUsedOneNode(layer) * layers.get(layer).size(); + long neighborSize = ramBytesUsedOneNode(level) * layers.get(level).size(); return OH_BYTES + REF_BYTES * 2L + AH_BYTES + neighborSize; } - public long ramBytesUsedOneNode(int layer) { + public long ramBytesUsedOneNode(int level) { // we include the REF_BYTES for the CNS reference here to make it self-contained for addGraphNode() int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF; - return REF_BYTES + Neighbors.ramBytesUsed(layers.get(layer).nodeArrayLength()); + return REF_BYTES + Neighbors.ramBytesUsed(layers.get(level).nodeArrayLength()); 
} @Override - public String toString() { - return String.format("OnHeapGraphIndex(size=%d, entryPoint=%s)", size(0), entryPoint.get()); + public void enforceDegree(int node) { + for (int level = 0; level <= getMaxLevel(); level++) { + layers.get(level).enforceDegree(node); + } } @Override - public void close() { - // No resources to close. + public void addEdges(int level, int node, NodeArray candidates, float overflowRatio) { + var newNeighbors = layers.get(level).insertDiverse(node, candidates); + layers.get(level).backlink(newNeighbors, node, overflowRatio); } - /** - * Returns a view of the graph that is safe to use concurrently with updates performed on the - * underlying graph. - * - *

Multiple Views may be searched concurrently. - */ @Override - public ConcurrentGraphIndexView getView() { - return new ConcurrentGraphIndexView(); + public void replaceDeletedNeighbors(int level, int node, BitSet toDelete, NodeArray candidates) { + layers.get(level).replaceDeletedNeighbors(node, toDelete, candidates); } - /** - * A View that assumes no concurrent modifications are made - */ - public GraphIndex.View getFrozenView() { - return new FrozenView(); + @Override + public String toString() { + return String.format("OnHeapGraphIndex(size=%d, entryPoint=%s)", size(0), entryPoint.get()); } - /** - * Validates that the current entry node has been completely added. - */ - void validateEntryNode() { - if (size(0) == 0) { - return; - } - NodeAtLevel entry = getView().entryNode(); - if (entry == null || getNeighbors(entry.level, entry.node) == null) { - throw new IllegalStateException("Entry node was incompletely added! " + entry); + @Override + public void close() { + // No resources to close. + } + + @Override + public View getView() { + // Before all completions are completed, we need a View that is thread-safe and allows concurrent mutations in the graph. + // Once all completions are completed, we can freeze the graph and just need a View that is thread-safe. + if (allMutationsCompleted) { + return new FrozenView(); + } else { + return new ConcurrentGraphIndexView(); } } @@ -288,13 +297,8 @@ public ThreadSafeGrowableBitSet getDeletedNodes() { return deletedNodes; } - /** - * Removes the given node from all layers. 
- * - * @param node the node id to remove - * @return the number of layers from which it was removed - */ - int removeNode(int node) { + @Override + public int removeNode(int node) { int found = 0; for (var layer : layers) { if (layer.remove(node) != null) { @@ -351,15 +355,23 @@ public int maxDegree() { return maxDegrees.stream().mapToInt(i -> i).max().orElseThrow(); } - public int getLayerSize(int level) { - return layers.get(level).size(); + @Override + public List maxDegrees() { + return maxDegrees; } + @Override public void setDegrees(List layerDegrees) { maxDegrees.clear(); maxDegrees.addAll(layerDegrees); } + @Override + public void allMutationsCompleted() { + allMutationsCompleted = true; + } + + /** * A concurrent View of the graph that is safe to search concurrently with updates and with other * searches. The View provides a limited kind of snapshot isolation: only nodes completely added @@ -400,7 +412,15 @@ private int advance() { @Override public int size() { - throw new UnsupportedOperationException(); + NodesIterator it = OnHeapGraphIndex.this.getNeighborsIterator(level, node); + int size = 0; + while (it.hasNext()) { + int n = it.nextInt(); + if (completions.completedAt(n) < timestamp) { + size++; + } + } + return size; } @Override @@ -450,6 +470,11 @@ public int getIdUpperBound() { return OnHeapGraphIndex.this.getIdUpperBound(); } + @Override + public boolean contains(int level, int node) { + return OnHeapGraphIndex.this.contains(level, node); + } + @Override public void close() { // No resources to close @@ -465,6 +490,7 @@ public String toString() { /** * Saves the graph to the given DataOutput for reloading into memory later */ + @Deprecated public void save(DataOutput out) { if (deletedNodes.cardinality() > 0) { throw new IllegalStateException("Cannot save a graph that has deleted nodes. 
Call cleanup() first"); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java index 0a471e74b..761024ff8 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java @@ -17,9 +17,7 @@ package io.github.jbellis.jvector.graph.disk; import io.github.jbellis.jvector.disk.IndexWriter; -import io.github.jbellis.jvector.disk.RandomAccessWriter; -import io.github.jbellis.jvector.graph.GraphIndex; -import io.github.jbellis.jvector.graph.OnHeapGraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.disk.feature.*; import org.agrona.collections.Int2IntHashMap; @@ -37,8 +35,7 @@ public abstract class AbstractGraphIndexWriter implements public static final int FOOTER_MAGIC_SIZE = Integer.BYTES; // The size of the magic number in the footer public static final int FOOTER_SIZE = FOOTER_MAGIC_SIZE + FOOTER_OFFSET_SIZE; // The total size of the footer final int version; - final GraphIndex graph; - final GraphIndex.View view; + final ImmutableGraphIndex graph; final OrdinalMapper ordinalMapper; final int dimension; // we don't use Map features but EnumMap is the best way to make sure we don't @@ -51,7 +48,7 @@ public abstract class AbstractGraphIndexWriter implements AbstractGraphIndexWriter(T out, int version, - GraphIndex graph, + ImmutableGraphIndex graph, OrdinalMapper oldToNewOrdinals, int dimension, EnumMap features) @@ -61,7 +58,6 @@ public abstract class AbstractGraphIndexWriter implements } this.version = version; this.graph = graph; - this.view = graph instanceof OnHeapGraphIndex ? 
((OnHeapGraphIndex) graph).getFrozenView() : graph.getView(); this.ordinalMapper = oldToNewOrdinals; this.dimension = dimension; this.featureMap = features; @@ -105,7 +101,7 @@ boolean isSeparated(Feature feature) { * if i < j in `graph` then map[i] < map[j] in the returned map. "Holes" left by * deleted nodes are filled in by shifting down the new ordinals. */ - public static Map sequentialRenumbering(GraphIndex graph) { + public static Map sequentialRenumbering(ImmutableGraphIndex graph) { try (var view = graph.getView()) { Int2IntHashMap oldToNewMap = new Int2IntHashMap(-1); int nextOrdinal = 0; @@ -133,7 +129,7 @@ public static Map sequentialRenumbering(GraphIndex graph) { * @param headerOffset the offset of the header in the slice * @throws IOException IOException */ - void writeFooter(long headerOffset) throws IOException { + void writeFooter(ImmutableGraphIndex.View view, long headerOffset) throws IOException { var layerInfo = CommonHeader.LayerInfo.fromGraph(graph, ordinalMapper); var commonHeader = new CommonHeader(version, dimension, @@ -155,7 +151,7 @@ void writeFooter(long headerOffset) throws IOException { * Public so that you can write the index size (and thus usefully open an OnDiskGraphIndex against the index) * to read Features from it before writing the edges. 
*/ - public synchronized void writeHeader(long startOffset) throws IOException { + public synchronized void writeHeader(ImmutableGraphIndex.View view, long startOffset) throws IOException { // graph-level properties var layerInfo = CommonHeader.LayerInfo.fromGraph(graph, ordinalMapper); var commonHeader = new CommonHeader(version, @@ -168,7 +164,7 @@ public synchronized void writeHeader(long startOffset) throws IOException { assert out.position() == startOffset + headerSize : String.format("%d != %d", out.position(), startOffset + headerSize); } - void writeSparseLevels() throws IOException { + void writeSparseLevels(ImmutableGraphIndex.View view) throws IOException { // write sparse levels for (int level = 1; level <= graph.getMaxLevel(); level++) { int layerSize = graph.size(level); @@ -237,13 +233,13 @@ void writeSeparatedFeatures(Map> featureSt * T - the type of the output stream */ public abstract static class Builder, T extends IndexWriter> { - final GraphIndex graphIndex; + final ImmutableGraphIndex graphIndex; final EnumMap features; final T out; OrdinalMapper ordinalMapper; int version; - public Builder(GraphIndex graphIndex, T out) { + public Builder(ImmutableGraphIndex graphIndex, T out) { this.graphIndex = graphIndex; this.out = out; this.features = new EnumMap<>(FeatureId.class); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java index 6853e77ee..5d0a1aecb 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java @@ -19,12 +19,10 @@ import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; -import io.github.jbellis.jvector.disk.RandomAccessWriter; -import 
io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.DataOutput; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -174,7 +172,7 @@ public LayerInfo(int size, int degree) { this.degree = degree; } - public static List fromGraph(GraphIndex graph, OrdinalMapper mapper) { + public static List fromGraph(ImmutableGraphIndex graph, OrdinalMapper mapper) { return IntStream.rangeClosed(0, graph.getMaxLevel()) .mapToObj(i -> new LayerInfo(graph.size(i), graph.getDegree(i))) .collect(Collectors.toList()); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index 6239ecc1b..a597aa78f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -19,7 +19,7 @@ import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.disk.ReaderSupplier; -import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.NodesIterator; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.disk.feature.Feature; @@ -61,7 +61,7 @@ * This graph may be extended with additional features, which are stored inline in the graph and in headers. * At runtime, this class may choose the best way to use these features. 
*/ -public class OnDiskGraphIndex implements GraphIndex, AutoCloseable, Accountable +public class OnDiskGraphIndex implements ImmutableGraphIndex, AutoCloseable, Accountable { private static final Logger logger = LoggerFactory.getLogger(OnDiskGraphIndex.class); public static final int CURRENT_VERSION = 5; @@ -239,6 +239,11 @@ public int getDegree(int level) { return layerInfo.get(level).degree; } + @Override + public List maxDegrees() { + return layerInfo.stream().map(l -> l.degree).collect(Collectors.toList()); + } + @Override public int getIdUpperBound() { return idUpperBound; @@ -326,6 +331,18 @@ public View getView() { } } + @Override + public double getAverageDegree(int level) { + var view = this.getView(); + var it = this.getNodes(level); + long sum = 0; + while (it.hasNext()) { + int node = it.next(); + sum += view.getNeighborsIterator(level, node).size(); + } + return (double) sum / it.size(); + } + public class View implements FeatureSource, ScoringView, RandomAccessVectorValues { protected final RandomAccessReader reader; private final int[] neighbors; @@ -457,6 +474,21 @@ public int getIdUpperBound() { return idUpperBound; } + @Override + public boolean contains(int level, int node) { + try { + if (level == 0) { + return node < idUpperBound; + } else { + // For levels > 0, read from memory + var imn = getInMemoryLayers(reader); + return imn.get(level).containsKey(node); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + @Override public Bits liveNodes() { return Bits.ALL; @@ -489,12 +521,12 @@ public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(Vector } /** Convenience function for writing a vanilla DiskANN-style index with no extra Features. 
*/ - public static void write(GraphIndex graph, RandomAccessVectorValues vectors, Path path) throws IOException { + public static void write(ImmutableGraphIndex graph, RandomAccessVectorValues vectors, Path path) throws IOException { write(graph, vectors, OnDiskGraphIndexWriter.sequentialRenumbering(graph), path); } /** Convenience function for writing a vanilla DiskANN-style index with no extra Features. */ - public static void write(GraphIndex graph, + public static void write(ImmutableGraphIndex graph, RandomAccessVectorValues vectors, Map oldToNewOrdinals, Path path) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java index 9dc6121ba..a8515c191 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java @@ -18,15 +18,10 @@ import io.github.jbellis.jvector.disk.BufferedRandomAccessWriter; import io.github.jbellis.jvector.disk.RandomAccessWriter; -import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.OnHeapGraphIndex; import io.github.jbellis.jvector.graph.disk.feature.Feature; import io.github.jbellis.jvector.graph.disk.feature.FeatureId; -import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; -import io.github.jbellis.jvector.graph.disk.feature.NVQ; -import io.github.jbellis.jvector.graph.disk.feature.SeparatedFeature; -import io.github.jbellis.jvector.graph.disk.feature.SeparatedNVQ; -import io.github.jbellis.jvector.graph.disk.feature.SeparatedVectors; import java.io.FileNotFoundException; import java.io.IOException; @@ -75,7 +70,7 @@ public class OnDiskGraphIndexWriter extends AbstractGraphIndexWriter features) @@ -90,7 +85,6 @@ public class OnDiskGraphIndexWriter 
extends AbstractGraphIndexWriter> featur throw new IllegalStateException(msg); } - writeHeader(); // sets position to start writing features + var view = graph.getView(); + + writeHeader(view); // sets position to start writing features // for each graph node, write the associated features, followed by its neighbors at L0 for (int newOrdinal = 0; newOrdinal <= ordinalMapper.maxOrdinal(); newOrdinal++) { @@ -217,20 +213,21 @@ public synchronized void write(Map> featur } // We will use the abstract method because no random access is needed - writeSparseLevels(); + writeSparseLevels(view); // We will use the abstract method because no random access is needed writeSeparatedFeatures(featureStateSuppliers); // Write the header again with updated offsets if (version >= 5) { - writeFooter(out.position()); + writeFooter(view, out.position()); } final var endOfGraphPosition = out.position(); - writeHeader(); + writeHeader(view); out.seek(endOfGraphPosition); out.flush(); + view.close(); } /** @@ -239,10 +236,10 @@ public synchronized void write(Map> featur * seek to the startOffset and re-write the header. 
* @throws IOException if there is an error writing the header */ - public synchronized void writeHeader() throws IOException { + public synchronized void writeHeader(ImmutableGraphIndex.View view) throws IOException { // graph-level properties out.seek(startOffset); - super.writeHeader(startOffset); + super.writeHeader(view, startOffset); out.flush(); } @@ -258,11 +255,11 @@ public synchronized long checksum() throws IOException { public static class Builder extends AbstractGraphIndexWriter.Builder { private long startOffset = 0L; - public Builder(GraphIndex graphIndex, Path outPath) throws FileNotFoundException { + public Builder(ImmutableGraphIndex graphIndex, Path outPath) throws FileNotFoundException { this(graphIndex, new BufferedRandomAccessWriter(outPath)); } - public Builder(GraphIndex graphIndex, RandomAccessWriter out) { + public Builder(ImmutableGraphIndex graphIndex, RandomAccessWriter out) { super(graphIndex, out); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java index c7591f2ef..e9afd4b41 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java @@ -17,17 +17,14 @@ package io.github.jbellis.jvector.graph.disk; import io.github.jbellis.jvector.disk.IndexWriter; -import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.OnHeapGraphIndex; import io.github.jbellis.jvector.graph.disk.feature.*; import java.io.IOException; import java.util.EnumMap; -import java.util.List; import java.util.Map; -import java.util.Set; import java.util.function.IntFunction; -import java.util.stream.Collectors; /** * Writes a graph index to disk in a 
format that can be loaded as an OnDiskGraphIndex. @@ -64,7 +61,7 @@ public class OnDiskSequentialGraphIndexWriter extends AbstractGraphIndexWriter features) @@ -74,7 +71,6 @@ public class OnDiskSequentialGraphIndexWriter extends AbstractGraphIndexWriter> featureStateSuppliers) throws IOException { - final var startOffset = out.position(); - writeHeader(startOffset); if (graph instanceof OnHeapGraphIndex) { var ohgi = (OnHeapGraphIndex) graph; if (ohgi.getDeletedNodes().cardinality() > 0) { @@ -108,6 +102,11 @@ public synchronized void write(Map> featur throw new IllegalStateException(msg); } + var view = graph.getView(); + + final var startOffset = out.position(); + writeHeader(view, startOffset); + // for each graph node, write the associated features, followed by its neighbors at L0 for (int newOrdinal = 0; newOrdinal <= ordinalMapper.maxOrdinal(); newOrdinal++) { var originalOrdinal = ordinalMapper.newToOld(newOrdinal); @@ -157,20 +156,22 @@ public synchronized void write(Map> featur } } - writeSparseLevels(); + writeSparseLevels(view); writeSeparatedFeatures(featureStateSuppliers); // Write the footer with all the metadata info about the graph - writeFooter(out.position()); + writeFooter(view, out.position()); // Note: flushing the data output is the responsibility of the caller we are not going to make assumptions about further uses of the data outputs + + view.close(); } /** * Builder for {@link OnDiskSequentialGraphIndexWriter}, with optional features. 
*/ public static class Builder extends AbstractGraphIndexWriter.Builder { - public Builder(GraphIndex graphIndex, IndexWriter out) { + public Builder(ImmutableGraphIndex graphIndex, IndexWriter out) { super(graphIndex, out); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java index 7ff4bdff9..59ca11564 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java @@ -17,7 +17,7 @@ package io.github.jbellis.jvector.graph.disk.feature; import io.github.jbellis.jvector.disk.RandomAccessReader; -import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.disk.CommonHeader; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; @@ -118,11 +118,11 @@ public void writeInline(DataOutput out, Feature.State state_) throws IOException } public static class State implements Feature.State { - public final GraphIndex.View view; + public final ImmutableGraphIndex.View view; public final PQVectors pqVectors; public final int nodeId; - public State(GraphIndex.View view, PQVectors pqVectors, int nodeId) { + public State(ImmutableGraphIndex.View view, PQVectors pqVectors, int nodeId) { this.view = view; this.pqVectors = pqVectors; this.nodeId = nodeId; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java index bdf91beb7..a4d62645f 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java @@ -29,10 +29,9 @@ import 
io.github.jbellis.jvector.example.util.CompressorParameters; import io.github.jbellis.jvector.example.util.DataSet; import io.github.jbellis.jvector.example.util.FilteredForkJoinPool; -import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.GraphIndexBuilder; import io.github.jbellis.jvector.graph.GraphSearcher; -import io.github.jbellis.jvector.graph.OnHeapGraphIndex; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.disk.feature.Feature; import io.github.jbellis.jvector.graph.disk.feature.FeatureId; @@ -71,7 +70,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.ForkJoinPool; import java.util.function.Function; import java.util.function.IntFunction; import java.util.stream.IntStream; @@ -160,7 +158,7 @@ static void runOneGraph(List> featureSets, DataSet ds, Path testDirectory) throws IOException { - Map, GraphIndex> indexes; + Map, ImmutableGraphIndex> indexes; if (buildCompressor == null) { indexes = buildInMemory(featureSets, M, efConstruction, neighborOverflow, addHierarchy, refineFinalGraph, ds, testDirectory); } else { @@ -199,15 +197,15 @@ static void runOneGraph(List> featureSets, } } - private static Map, GraphIndex> buildOnDisk(List> featureSets, - int M, - int efConstruction, - float neighborOverflow, - boolean addHierarchy, - boolean refineFinalGraph, - DataSet ds, - Path testDirectory, - VectorCompressor buildCompressor) + private static Map, ImmutableGraphIndex> buildOnDisk(List> featureSets, + int M, + int efConstruction, + float neighborOverflow, + boolean addHierarchy, + boolean refineFinalGraph, + DataSet ds, + Path testDirectory, + VectorCompressor buildCompressor) throws IOException { var floatVectors = ds.getBaseRavv(); @@ -282,7 +280,7 @@ private static Map, GraphIndex> buildOnDisk(List, GraphIndex> indexes = new HashMap<>(); + Map, 
ImmutableGraphIndex> indexes = new HashMap<>(); n = 0; for (var features : featureSets) { var graphPath = testDirectory.resolve("graph" + n++); @@ -293,7 +291,7 @@ private static Map, GraphIndex> buildOnDisk(List features, - OnHeapGraphIndex onHeapGraph, + ImmutableGraphIndex onHeapGraph, Path outPath, RandomAccessVectorValues floatVectors, ProductQuantization pq) @@ -358,18 +356,18 @@ public BuilderWithSuppliers(OnDiskGraphIndexWriter.Builder builder, Map, GraphIndex> buildInMemory(List> featureSets, - int M, - int efConstruction, - float neighborOverflow, - boolean addHierarchy, - boolean refineFinalGraph, - DataSet ds, - Path testDirectory) + private static Map, ImmutableGraphIndex> buildInMemory(List> featureSets, + int M, + int efConstruction, + float neighborOverflow, + boolean addHierarchy, + boolean refineFinalGraph, + DataSet ds, + Path testDirectory) throws IOException { var floatVectors = ds.getBaseRavv(); - Map, GraphIndex> indexes = new HashMap<>(); + Map, ImmutableGraphIndex> indexes = new HashMap<>(); long start; var bsp = BuildScoreProvider.randomAccessScoreProvider(floatVectors, ds.similarityFunction); GraphIndexBuilder builder = new GraphIndexBuilder(bsp, @@ -393,7 +391,7 @@ private static Map, GraphIndex> buildInMemory(List runAllAndCollectResults( var searchCompressorObj = getCompressor(searchCompressor, ds); CompressedVectors cvArg = (searchCompressorObj instanceof CompressedVectors) ? 
(CompressedVectors) searchCompressorObj : null; var indexes = buildOnDisk(List.of(features), m, ef, neighborOverflow, addHierarchy, false, ds, testDirectory, compressor); - GraphIndex index = indexes.get(features); + ImmutableGraphIndex index = indexes.get(features); try (ConfiguredSystem cs = new ConfiguredSystem(ds, index, cvArg, features)) { int queryRuns = 2; List benchmarks = List.of( @@ -659,7 +657,7 @@ private static VectorCompressor getCompressor(Function features; @@ -667,20 +665,20 @@ public static class ConfiguredSystem implements AutoCloseable { return new GraphSearcher(index); }); - ConfiguredSystem(DataSet ds, GraphIndex index, CompressedVectors cv, Set features) { + ConfiguredSystem(DataSet ds, ImmutableGraphIndex index, CompressedVectors cv, Set features) { this.ds = ds; this.index = index; this.cv = cv; this.features = features; } - public SearchScoreProvider scoreProviderFor(VectorFloat queryVector, GraphIndex.View view) { + public SearchScoreProvider scoreProviderFor(VectorFloat queryVector, ImmutableGraphIndex.View view) { // if we're not compressing then just use the exact score function if (cv == null) { return DefaultSearchScoreProvider.exact(queryVector, ds.similarityFunction, ds.getBaseRavv()); } - var scoringView = (GraphIndex.ScoringView) view; + var scoringView = (ImmutableGraphIndex.ScoringView) view; ScoreFunction.ApproximateScoreFunction asf; if (features.contains(FeatureId.FUSED_ADC)) { asf = scoringView.approximateScoreFunctionFor(queryVector, ds.similarityFunction); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java index 21c3e1814..3c125fd2b 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java @@ -18,16 +18,14 @@ import 
io.github.jbellis.jvector.example.util.MMapRandomAccessVectorValues; import io.github.jbellis.jvector.example.util.UpdatableRandomAccessVectorValues; -import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.GraphIndexBuilder; import io.github.jbellis.jvector.graph.GraphSearcher; -import io.github.jbellis.jvector.graph.OnHeapGraphIndex; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.SearchResult; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; -import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; import io.github.jbellis.jvector.quantization.CompressedVectors; import io.github.jbellis.jvector.quantization.ProductQuantization; import io.github.jbellis.jvector.util.Bits; @@ -74,7 +72,7 @@ static class SessionContext { RandomAccessVectorValues ravv; CompressedVectors cv; GraphIndexBuilder indexBuilder; - GraphIndex index; + ImmutableGraphIndex index; GraphSearcher searcher; final StringBuffer result = new StringBuffer(1024); } @@ -191,13 +189,13 @@ private CompressedVectors pqIndex(RandomAccessVectorValues ravv, SessionContext return cv; } - private static GraphIndex flushGraphIndex(OnHeapGraphIndex onHeapIndex, RandomAccessVectorValues ravv) { + private static ImmutableGraphIndex flushGraphIndex(ImmutableGraphIndex index, RandomAccessVectorValues ravv) { try { var testDirectory = Files.createTempDirectory("BenchGraphDir"); var graphPath = testDirectory.resolve("graph.bin"); - OnDiskGraphIndex.write(onHeapIndex, ravv, graphPath); - return onHeapIndex; + OnDiskGraphIndex.write(index, ravv, graphPath); + return index; } catch (IOException e) { throw new IOError(e); } @@ -263,8 +261,8 @@ String search(String input, SessionContext ctx) { if (ctx.cv != 
null) { ScoreFunction.ApproximateScoreFunction sf = ctx.cv.precomputedScoreFunctionFor(queryVector, ctx.similarityFunction); try (var view = ctx.index.getView()) { - var rr = view instanceof GraphIndex.ScoringView - ? ((GraphIndex.ScoringView) view).rerankerFor(queryVector, ctx.similarityFunction) + var rr = view instanceof ImmutableGraphIndex.ScoringView + ? ((ImmutableGraphIndex.ScoringView) view).rerankerFor(queryVector, ctx.similarityFunction) : ctx.ravv.rerankerFor(queryVector, ctx.similarityFunction); var ssp = new DefaultSearchScoreProvider(sf, rr); r = new GraphSearcher(ctx.index).search(ssp, searchEf, Bits.ALL); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java index fa57b1066..e0785e28b 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java @@ -21,11 +21,10 @@ import io.github.jbellis.jvector.disk.ReaderSupplierFactory; import io.github.jbellis.jvector.example.util.AccuracyMetrics; import io.github.jbellis.jvector.example.util.SiftLoader; -import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.GraphIndexBuilder; import io.github.jbellis.jvector.graph.GraphSearcher; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; -import io.github.jbellis.jvector.graph.OnHeapGraphIndex; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.SearchResult; import io.github.jbellis.jvector.graph.disk.feature.Feature; @@ -90,7 +89,7 @@ public static void siftInMemory(List> baseVectors) throws IOExcep true)) { // build the index (in memory) - OnHeapGraphIndex index = builder.build(ravv); + ImmutableGraphIndex index = builder.build(ravv); // search for a random vector 
VectorFloat q = randomVector(originalDimension); @@ -113,7 +112,7 @@ public static void siftInMemoryWithSearcher(List> baseVectors) th BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN); try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f, false, true)) { - OnHeapGraphIndex index = builder.build(ravv); + ImmutableGraphIndex index = builder.build(ravv); // search for a random vector using a GraphSearcher and SearchScoreProvider VectorFloat q = randomVector(originalDimension); @@ -134,7 +133,7 @@ public static void siftInMemoryWithRecall(List> baseVectors, List BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN); try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f, false, true)) { - OnHeapGraphIndex index = builder.build(ravv); + ImmutableGraphIndex index = builder.build(ravv); // measure our recall against the (exactly computed) ground truth Function, SearchScoreProvider> sspFactory = q -> DefaultSearchScoreProvider.exact(q, VectorSimilarityFunction.EUCLIDEAN, ravv); testRecall(index, queryVectors, groundTruth, sspFactory); @@ -150,7 +149,7 @@ public static void siftPersisted(List> baseVectors, List> baseVectors, List randomVector(int dim) { return vec; } - private static void testRecall(GraphIndex graph, + private static void testRecall(ImmutableGraphIndex graph, List> queryVectors, List> groundTruth, Function, diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java index a97e4858f..21de0fede 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java @@ -17,10 +17,9 @@ package io.github.jbellis.jvector; import io.github.jbellis.jvector.disk.BufferedRandomAccessWriter; 
-import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.GraphIndexBuilder; import io.github.jbellis.jvector.graph.NodesIterator; -import io.github.jbellis.jvector.graph.OnHeapGraphIndex; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.disk.CommonHeader; import io.github.jbellis.jvector.graph.disk.feature.Feature; @@ -35,6 +34,7 @@ import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.apache.commons.lang3.NotImplementedException; import java.io.BufferedOutputStream; import java.io.DataOutputStream; @@ -147,11 +147,11 @@ public static List> createNormalRandomVectors(int count, int dime return IntStream.range(0, count).mapToObj(i -> TestUtil.normalRandomVector(getRandom(), dimension)).collect(Collectors.toList()); } - public static void writeGraph(GraphIndex graph, RandomAccessVectorValues ravv, Path outputPath) throws IOException { + public static void writeGraph(ImmutableGraphIndex graph, RandomAccessVectorValues ravv, Path outputPath) throws IOException { OnDiskGraphIndex.write(graph, ravv, outputPath); } - public static void writeFusedGraph(GraphIndex graph, RandomAccessVectorValues ravv, PQVectors pqv, Path outputPath) throws IOException { + public static void writeFusedGraph(ImmutableGraphIndex graph, RandomAccessVectorValues ravv, PQVectors pqv, Path outputPath) throws IOException { try (var writer = new OnDiskGraphIndexWriter.Builder(graph, outputPath) .with(new InlineVectors(ravv.dimension())) .with(new FusedADC(graph.maxDegree(), pqv.getCompressor())).build()) @@ -163,7 +163,7 @@ public static void writeFusedGraph(GraphIndex graph, RandomAccessVectorValues ra } } - public static Set getNeighborNodes(GraphIndex.View g, int level, int node) { + public static Set 
getNeighborNodes(ImmutableGraphIndex.View g, int level, int node) { Set neighbors = new HashSet<>(); for (var it = g.getNeighborsIterator(level, node); it.hasNext(); ) { int n = it.nextInt(); @@ -172,7 +172,7 @@ public static Set getNeighborNodes(GraphIndex.View g, int level, int no return neighbors; } - static List sortedNodes(GraphIndex h, int level) { + static List sortedNodes(ImmutableGraphIndex h, int level) { var graphNodes = h.getNodes(level); // TODO List nodes = new ArrayList<>(); while (graphNodes.hasNext()) { @@ -182,15 +182,15 @@ static List sortedNodes(GraphIndex h, int level) { return nodes; } - public static void assertGraphEquals(GraphIndex g, GraphIndex h) { + public static void assertGraphEquals(ImmutableGraphIndex g, ImmutableGraphIndex h) { // construct these up front since they call seek which will mess up our test loop - String prettyG = GraphIndex.prettyPrint(g); - String prettyH = GraphIndex.prettyPrint(h); + String prettyG = ImmutableGraphIndex.prettyPrint(g); + String prettyH = ImmutableGraphIndex.prettyPrint(h); assertEquals(String.format("the number of nodes in the graphs are different:%n%s%n%s", prettyG, prettyH), - g.size(), - h.size()); + g.size(0), + h.size(0)); assertEquals(g.getView().entryNode(), h.getView().entryNode()); for (int level = 0; level <= g.getMaxLevel(); level++) { @@ -228,7 +228,7 @@ public static void assertEqualsLazy(Supplier f, Set s1, Set layerSizes; @@ -259,6 +259,21 @@ public int maxDegree() { return layerSizes.stream().mapToInt(i -> i).max().orElseThrow(); } + @Override + public List maxDegrees() { + throw new NotImplementedException(); + } + + @Override + public int getIdUpperBound() { + return ImmutableGraphIndex.super.getIdUpperBound(); + } + + @Override + public boolean containsNode(int nodeId) { + return ImmutableGraphIndex.super.containsNode(nodeId); + } + @Override public NodesIterator getNodes(int level) { int n = layerSizes.get(level); @@ -275,6 +290,11 @@ public int getDegree(int level) { return 
layerSizes.get(level) - 1; } + @Override + public double getAverageDegree(int level) { + throw new NotImplementedException(); + } + @Override public int getMaxLevel() { return layerSizes.size() - 1; @@ -291,6 +311,7 @@ public NodesIterator getNeighborsIterator(int level, int node) { layerSizes.get(level) - 1); } + @Deprecated @Override public int size() { return FullyConnectedGraphIndex.this.size(0); @@ -308,6 +329,11 @@ public Bits liveNodes() { @Override public void close() { } + + @Override + public boolean contains(int level, int node) { + return node < layerSizes.get(level); + } } @Override @@ -316,7 +342,7 @@ public long ramBytesUsed() { } } - public static class RandomlyConnectedGraphIndex implements GraphIndex { + public static class RandomlyConnectedGraphIndex implements ImmutableGraphIndex { private final List layerInfo; private final List> layerAdjacency; private final int entryNode; @@ -383,11 +409,31 @@ public int getDegree(int level) { return layerInfo.get(level).degree; } + @Override + public double getAverageDegree(int level) { + throw new NotImplementedException(); + } + @Override public int maxDegree() { return layerInfo.stream().mapToInt(li -> li.degree).max().orElseThrow(); } + @Override + public List maxDegrees() { + throw new NotImplementedException(); + } + + @Override + public int getIdUpperBound() { + return ImmutableGraphIndex.super.getIdUpperBound(); + } + + @Override + public boolean containsNode(int nodeId) { + return ImmutableGraphIndex.super.containsNode(nodeId); + } + @Override public void close() { } @@ -398,6 +444,8 @@ public NodesIterator getNeighborsIterator(int level, int node) { return new NodesIterator.ArrayNodesIterator(adjacency.get(node)); } + @Deprecated + @Override public int size() { return layerInfo.get(0).size; } @@ -412,6 +460,11 @@ public Bits liveNodes() { return Bits.ALL; } + @Override + public boolean contains(int level, int node) { + return layerAdjacency.get(level).containsKey(node); + } + @Override public void 
close() { } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java index a6ce86532..716621d21 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java @@ -94,7 +94,9 @@ public void testRescore(boolean addHierarchy) { builder.addGraphNode(0, ravv.getVector(0)); builder.addGraphNode(1, ravv.getVector(1)); builder.addGraphNode(2, ravv.getVector(2)); - var neighbors = builder.graph.getNeighbors(0, 0); + + var ohgi = (OnHeapGraphIndex) builder.graph; + var neighbors = ohgi.getNeighbors(0, 0); assertEquals(1, neighbors.getNode(0)); assertEquals(2, neighbors.getNode(1)); assertEquals(0.5f, neighbors.getScore(0), 1E-6); @@ -111,7 +113,7 @@ public void testRescore(boolean addHierarchy) { var rescored = GraphIndexBuilder.rescore(builder, bsp); // Verify edges still exist - var newGraph = rescored.getGraph(); + var newGraph = (OnHeapGraphIndex) rescored.getGraph(); assertTrue(newGraph.containsNode(0)); assertTrue(newGraph.containsNode(1)); assertTrue(newGraph.containsNode(2)); @@ -140,7 +142,7 @@ public void testSaveAndLoad() throws IOException { var graph = TestUtil.buildSequentially(builder, ravv); try (var out = TestUtil.openDataOutputStream(indexDataPath)) { - graph.save(out); + ((OnHeapGraphIndex) graph).save(out); } builder = newBuilder.get(); @@ -148,7 +150,7 @@ public void testSaveAndLoad() throws IOException { builder.load(readerSupplier.get()); } - assertEquals(ravv.size(), builder.graph.size()); + assertEquals(ravv.size(), builder.graph.size(0)); for (int i = 0; i < ravv.size(); i++) { assertTrue(builder.graph.containsNode(i)); } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java 
b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java index 56a35fe42..da052a617 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java @@ -65,7 +65,7 @@ public void testMarkDeleted(boolean addHierarchy) { // check that asking for the entire graph back still doesn't surface the deleted one var v = ravv.getVector(n).copy(); var results = GraphSearcher.search(v, ravv.size(), ravv, VectorSimilarityFunction.COSINE, graph, Bits.ALL); - assertEquals(GraphIndex.prettyPrint(graph), ravv.size() - 1, results.getNodes().length); + assertEquals(ImmutableGraphIndex.prettyPrint(graph), ravv.size() - 1, results.getNodes().length); for (var ns : results.getNodes()) { assertNotEquals(n, ns.node); } @@ -88,7 +88,7 @@ public void testCleanup(boolean addHierarchy) throws IOException { int nodeToIsolate = getRandom().nextInt(ravv.size()); int nDeleted = 0; try (var view = graph.getView()) { - for (var i = 0; i < view.size(); i++) { + for (var i = 0; i < graph.size(0); i++) { for (var it = view.getNeighborsIterator(0, i); it.hasNext(); ) { // TODO hardcoded level if (nodeToIsolate == it.nextInt()) { builder.markNodeDeleted(i); @@ -102,18 +102,20 @@ public void testCleanup(boolean addHierarchy) throws IOException { // cleanup removes the deleted nodes builder.cleanup(); - assertEquals(ravv.size() - nDeleted, graph.size()); + assertEquals(ravv.size() - nDeleted, graph.size(0)); // cleanup should have added new connections to the node that would otherwise have been disconnected var v = ravv.getVector(nodeToIsolate).copy(); var results = GraphSearcher.search(v, 10, ravv, VectorSimilarityFunction.COSINE, graph, Bits.ALL); assertEquals(nodeToIsolate, results.getNodes()[0].node); + var ohgi = (OnHeapGraphIndex) graph; + // check that we can save and load the graph with "holes" from the deletion var testDirectory = 
Files.createTempDirectory(this.getClass().getSimpleName()); var outputPath = testDirectory.resolve("on_heap_graph"); try (var out = TestUtil.openDataOutputStream(outputPath)) { - graph.save(out); + ohgi.save(out); } var b2 = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 4, 10, 1.0f, 1.0f, addHierarchy); @@ -138,14 +140,14 @@ public void testMarkingAllNodesAsDeleted(boolean addHierarchy) { var graph = TestUtil.buildSequentially(builder, ravv); // mark all deleted - for (var i = 0; i < graph.size(); i++) { - graph.markDeleted(i); + for (var i = 0; i < graph.size(0); i++) { + builder.markNodeDeleted(i); } // removeDeletedNodes should leave the graph empty builder.removeDeletedNodes(); - assertEquals(0, graph.size()); - assertNull(graph.entry()); + assertEquals(0, graph.size(0)); + assertNull(graph.getView().entryNode()); } @Test diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java index bcb413442..52bdc872a 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java @@ -29,9 +29,6 @@ import io.github.jbellis.jvector.LuceneTestCase; import io.github.jbellis.jvector.TestUtil; import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; -import io.github.jbellis.jvector.graph.similarity.ScoreFunction; -import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; -import io.github.jbellis.jvector.quantization.PQVectors; import io.github.jbellis.jvector.quantization.ProductQuantization; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.BoundedLongHeap; @@ -43,9 +40,9 @@ import org.junit.Before; import org.junit.Test; +import java.io.IOException; import java.util.Arrays; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import 
java.util.Random; import java.util.Set; @@ -280,30 +277,33 @@ public int size() { } // We expect to get approximately 100% recall; // the lowest docIds are closest to zero; sum(0,9) = 45 - assertTrue("sum(result docs)=" + sum + " for " + GraphIndex.prettyPrint(builder.graph), sum < 75); + assertTrue("sum(result docs)=" + sum + " for " + ImmutableGraphIndex.prettyPrint(builder.graph), sum < 75); } - private static void validateIndex(OnHeapGraphIndex graph) { + private static void validateIndex(ImmutableGraphIndex graph) { + var view = graph.getView(); + for (int level = graph.getMaxLevel(); level > 0; level--) { for (var nodeIt = graph.getNodes(level); nodeIt.hasNext(); ) { var nodeInLevel = nodeIt.nextInt(); // node's neighbors should also exist in the same level - var neighbors = graph.getNeighbors(level, nodeInLevel); - for (int neighbor : neighbors.copyDenseNodes()) { - assertNotNull(graph.getNeighbors(level, neighbor)); + var it = view.getNeighborsIterator(level, nodeInLevel); + while (it.hasNext()) { + int neighbor = it.nextInt(); + assertTrue(view.contains(level, neighbor)); } // node should exist at every layer below it for (int lowerLevel = level - 1; lowerLevel >= 0; lowerLevel--) { - assertNotNull(graph.getNeighbors(lowerLevel, nodeInLevel)); + assertTrue(view.contains(lowerLevel, nodeInLevel)); } } } // no holes in lowest level (not true for all graphs but true for the ones constructed here) for (int i = 0; i < graph.getIdUpperBound(); i++) { - assertNotNull(graph.getNeighbors(0, i)); + assertTrue(view.getNeighborsIterator(0, i).size() >= 0); } } @@ -343,7 +343,7 @@ public void testAknnDiverse(boolean addHierarchy) { } // We expect to get approximately 100% recall; // the lowest docIds are closest to zero; sum(0,9) = 45 - assertTrue("sum(result docs)=" + sum + " for " + GraphIndex.prettyPrint(builder.graph), sum < 75); + assertTrue("sum(result docs)=" + sum + " for " + ImmutableGraphIndex.prettyPrint(builder.graph), sum < 75); } @Test @@ -378,7 +378,7 @@ 
public void testSearchWithAcceptOrds(boolean addHierarchy) { } // We expect to get approximately 100% recall; // the lowest docIds are closest to zero; sum(0,9) = 45 - assertTrue("sum(result docs)=" + sum + " for " + GraphIndex.prettyPrint(builder.graph), sum < 75); + assertTrue("sum(result docs)=" + sum + " for " + ImmutableGraphIndex.prettyPrint(builder.graph), sum < 75); } @Test @@ -414,13 +414,13 @@ public void testSearchWithSelectiveAcceptOrds(boolean addHierarchy) { int[] nodes = Arrays.stream(nn).mapToInt(nodeScore -> nodeScore.node).toArray(); for (int node : nodes) { assertTrue(String.format("the results include a deleted document: %d for %s", - node, GraphIndex.prettyPrint(builder.graph)), acceptOrds.get(node)); + node, ImmutableGraphIndex.prettyPrint(builder.graph)), acceptOrds.get(node)); } for (int i = 0; i < acceptOrds.length(); i++) { if (acceptOrds.get(i)) { int finalI = i; assertTrue(String.format("the results do not include an accepted document: %d for %s", - i, GraphIndex.prettyPrint(builder.graph)), Arrays.stream(nodes).anyMatch(j -> j == finalI)); + i, ImmutableGraphIndex.prettyPrint(builder.graph)), Arrays.stream(nodes).anyMatch(j -> j == finalI)); } } } @@ -479,37 +479,49 @@ public void testDiversity(boolean addHierarchy) { builder.addGraphNode(0, vectors.getVector(0)); builder.addGraphNode(1, vectors.getVector(1)); builder.addGraphNode(2, vectors.getVector(2)); + + var view = builder.graph.getView(); + // now every node has tried to attach every other node as a neighbor, but // some were excluded based on diversity check. 
- assertNeighbors(builder.graph, 0, 1, 2); - assertNeighbors(builder.graph, 1, 0); - assertNeighbors(builder.graph, 2, 0); + assertNeighbors(view, 0, 1, 2); + assertNeighbors(view, 1, 0); + assertNeighbors(view, 2, 0); builder.addGraphNode(3, vectors.getVector(3)); - assertNeighbors(builder.graph, 0, 1, 2); + + view = builder.graph.getView(); + + assertNeighbors(view, 0, 1, 2); // we added 3 here - assertNeighbors(builder.graph, 1, 0, 3); - assertNeighbors(builder.graph, 2, 0); - assertNeighbors(builder.graph, 3, 1); + assertNeighbors(view, 1, 0, 3); + assertNeighbors(view, 2, 0); + assertNeighbors(view, 3, 1); // supplant an existing neighbor builder.addGraphNode(4, vectors.getVector(4)); + + view = builder.graph.getView(); + // 4 is the same distance from 0 that 2 is; we leave the existing node in place - assertNeighbors(builder.graph, 0, 1, 2); - assertNeighbors(builder.graph, 1, 0, 3, 4); - assertNeighbors(builder.graph, 2, 0); + assertNeighbors(view, 0, 1, 2); + assertNeighbors(view, 1, 0, 3, 4); + assertNeighbors(view, 2, 0); // 1 survives the diversity check - assertNeighbors(builder.graph, 3, 1, 4); - assertNeighbors(builder.graph, 4, 1, 3); + assertNeighbors(view, 3, 1, 4); + assertNeighbors(view, 4, 1, 3); builder.addGraphNode(5, vectors.getVector(5)); - assertNeighbors(builder.graph, 0, 1, 2); - assertNeighbors(builder.graph, 1, 0, 3, 4, 5); - assertNeighbors(builder.graph, 2, 0); + + view = builder.graph.getView(); + + assertNeighbors(view, 0, 1, 2); + assertNeighbors(view, 1, 0, 3, 4, 5); + assertNeighbors(view, 2, 0); // even though 5 is closer, 3 is not a neighbor of 5, so no update to *its* neighbors occurs - assertNeighbors(builder.graph, 3, 1, 4); - assertNeighbors(builder.graph, 4, 1, 3, 5); - assertNeighbors(builder.graph, 5, 1, 4); + assertNeighbors(view, 3, 1, 4); + assertNeighbors(view, 4, 1, 3, 5); + assertNeighbors(view, 5, 1, 4); } @Test @@ -538,18 +550,24 @@ public void testDiversityFallback(boolean addHierarchy) { builder.addGraphNode(0, 
vectors.getVector(0)); builder.addGraphNode(1, vectors.getVector(1)); builder.addGraphNode(2, vectors.getVector(2)); - assertNeighbors(builder.graph, 0, 1, 2); + + var view = builder.graph.getView(); + + assertNeighbors(view, 0, 1, 2); // 2 is closer to 0 than 1, so it is excluded as non-diverse - assertNeighbors(builder.graph, 1, 0); + assertNeighbors(view, 1, 0); // 1 is closer to 0 than 2, so it is excluded as non-diverse - assertNeighbors(builder.graph, 2, 0); + assertNeighbors(view, 2, 0); builder.addGraphNode(3, vectors.getVector(3)); + + view = builder.graph.getView(); + // this is one case we are testing; 2 has been displaced by 3 - assertNeighbors(builder.graph, 0, 1, 3); - assertNeighbors(builder.graph, 1, 0); - assertNeighbors(builder.graph, 2, 0); - assertNeighbors(builder.graph, 3, 0); + assertNeighbors(view, 0, 1, 3); + assertNeighbors(view, 1, 0); + assertNeighbors(view, 2, 0); + assertNeighbors(view, 3, 0); } @Test @@ -574,21 +592,27 @@ public void testDiversity3d(boolean addHierarchy) { builder.addGraphNode(0, vectors.getVector(0)); builder.addGraphNode(1, vectors.getVector(1)); builder.addGraphNode(2, vectors.getVector(2)); - assertNeighbors(builder.graph, 0, 1, 2); + + var view = builder.graph.getView(); + + assertNeighbors(view, 0, 1, 2); // 2 is closer to 0 than 1, so it is excluded as non-diverse - assertNeighbors(builder.graph, 1, 0); + assertNeighbors(view, 1, 0); // 1 is closer to 0 than 2, so it is excluded as non-diverse - assertNeighbors(builder.graph, 2, 0); + assertNeighbors(view, 2, 0); builder.addGraphNode(3, vectors.getVector(3)); + + view = builder.graph.getView(); + // this is one case we are testing; 1 has been displaced by 3 - assertNeighbors(builder.graph, 0, 2, 3); - assertNeighbors(builder.graph, 1, 0, 3); - assertNeighbors(builder.graph, 2, 0); - assertNeighbors(builder.graph, 3, 0, 1); + assertNeighbors(view, 0, 2, 3); + assertNeighbors(view, 1, 0, 3); + assertNeighbors(view, 2, 0); + assertNeighbors(view, 3, 0, 1); } - 
private void assertNeighbors(OnHeapGraphIndex graph, int node, int... expected) { + private void assertNeighbors(ImmutableGraphIndex.View graph, int node, int... expected) { Arrays.sort(expected); NodesIterator it = graph.getNeighborsIterator(0, node); int[] actual = new int[it.size()]; @@ -677,8 +701,9 @@ public void testConcurrentNeighbors(boolean addHierarchy) { GraphIndexBuilder builder = new GraphIndexBuilder(vectors, similarityFunction, 2, 30, 1.0f, 1.4f, addHierarchy); var graph = builder.build(vectors); validateIndex(graph); + var view = graph.getView(); for (int i = 0; i < vectors.size(); i++) { - assertTrue(graph.getNeighbors(0, i).size() <= 2); // TODO + assertTrue(view.getNeighborsIterator(0, i).size() <= 2); // TODO } } @@ -699,6 +724,8 @@ public void testZeroCentroid(boolean addHierarchy) { var results = GraphSearcher.search(qv, 1, vectors, VectorSimilarityFunction.COSINE, graph, Bits.ALL); assertEquals(1, results.getNodes().length); assertEquals(1, results.getNodes()[0].node); + } catch (IOException e) { + throw new RuntimeException(e); } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java index 26769709e..c1eaf4ec3 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java @@ -78,7 +78,7 @@ public void testSimpleGraphs() throws Exception { for (var graph : List.of(fullyConnectedGraph, randomlyConnectedGraph)) { var outputPath = testDirectory.resolve("test_graph_" + graph.getClass().getSimpleName()); - var ravv = new TestVectorGraph.CircularFloatVectorValues(graph.size()); + var ravv = new TestVectorGraph.CircularFloatVectorValues(graph.size(0)); TestUtil.writeGraph(graph, ravv, outputPath); try (var readerSupplier = new 
SimpleMappedReader.Supplier(outputPath.toAbsolutePath()); var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) @@ -109,7 +109,7 @@ public void testRenumberingOnDelete(boolean addHierarchy) throws IOException { builder.setEntryPoint(0, builder.getGraph().getIdUpperBound() - 1); // TODO // check - assertEquals(2, original.size()); + assertEquals(2, original.size(0)); var originalView = original.getView(); // 1 -> 2 assertEquals(1, getNeighborNodes(originalView, 0, 1).size()); @@ -244,7 +244,7 @@ private static void validateSeparatedNVQ(OnDiskGraphIndex.View view, public void testSimpleGraphSeparated() throws Exception { for (var graph : List.of(fullyConnectedGraph, randomlyConnectedGraph)) { var outputPath = testDirectory.resolve("test_graph_separated_" + graph.getClass().getSimpleName()); - var ravv = new TestVectorGraph.CircularFloatVectorValues(graph.size()); + var ravv = new TestVectorGraph.CircularFloatVectorValues(graph.size(0)); // Write graph with SEPARATED_VECTORS try (var writer = new OnDiskGraphIndexWriter.Builder(graph, outputPath) @@ -315,7 +315,7 @@ public void testLargeGraph() throws Exception { var graph = new TestUtil.RandomlyConnectedGraphIndex(1_000_000, 32, getRandom()); var outputPath = testDirectory.resolve("large_graph"); - var ravv = new TestVectorGraph.CircularFloatVectorValues(graph.size()); + var ravv = new TestVectorGraph.CircularFloatVectorValues(graph.size(0)); TestUtil.writeGraph(graph, ravv, outputPath); try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath.toAbsolutePath()); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskSequentialGraphIndexWriter.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskSequentialGraphIndexWriter.java index baf134cf7..0fb6d7ecb 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskSequentialGraphIndexWriter.java +++ 
b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskSequentialGraphIndexWriter.java @@ -20,7 +20,7 @@ import io.github.jbellis.jvector.LuceneTestCase; import io.github.jbellis.jvector.disk.SimpleMappedReader; import io.github.jbellis.jvector.disk.SimpleWriter; -import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.GraphIndexBuilder; import io.github.jbellis.jvector.TestUtil; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; @@ -84,7 +84,7 @@ void buildAndCompareGraphs(int size, int dimension, int maxConnections, int beam // Create random vectors and build a graph var ravv = new ListRandomAccessVectorValues(new ArrayList<>(TestUtil.createRandomVectors(size, dimension)), dimension); var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, maxConnections, beamWidth, neighborOverflow, alpha, addHierarchy); - GraphIndex graph = TestUtil.buildSequentially(builder, ravv); + ImmutableGraphIndex graph = TestUtil.buildSequentially(builder, ravv); // Create a sequential writer and write the graph Path indexPath = testDirectory.resolve("graph.index_with_hierarchy_" + addHierarchy); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java index 18791ea64..d6aefbe58 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java @@ -73,7 +73,7 @@ public void testFusedGraph() throws Exception { var reranker = cachedOnDiskView.rerankerFor(queryVector, similarityFunction); for (int i = 0; i < 50; i++) { var fusedScoreFunction = cachedOnDiskView.approximateScoreFunctionFor(queryVector, similarityFunction); - var ordinal = getRandom().nextInt(graph.size()); + 
var ordinal = getRandom().nextInt(graph.size(0)); // first pass compares fused ADC's direct similarity to reranker's similarity, used for comparisons to a specific node var neighbors = cachedOnDiskView.getNeighborsIterator(0, ordinal); for (; neighbors.hasNext(); ) {