diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java index 891fda756..ad6d137a0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java @@ -351,7 +351,7 @@ public NeighborWithShortEdges(Neighbors neighbors, double shortEdges) { } } - private static class NeighborIterator implements NodesIterator { + public static class NeighborIterator implements NodesIterator { private final NodeArray neighbors; private int i; @@ -374,5 +374,9 @@ public boolean hasNext() { public int nextInt() { return neighbors.getNode(i++); } + + public NodeArray merge(NodeArray other) { + return NodeArray.merge(neighbors, other); + } } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 6c9960ba4..c4dacd18e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -20,6 +20,8 @@ import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.GraphIndex.NodeAtLevel; import io.github.jbellis.jvector.graph.SearchResult.NodeScore; +import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; @@ -38,10 +40,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.*; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentSkipListSet; -import java.util.concurrent.ForkJoinPool; -import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.IntStream; @@ -297,8 +296,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, - ForkJoinPool parallelExecutor) - { + ForkJoinPool parallelExecutor) { if (maxDegrees.stream().anyMatch(i -> i <= 0)) { throw new IllegalArgumentException("layer degrees must be positive"); } @@ -339,6 +337,50 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, this.rng = new Random(0); } + /** + * Create this builder from an existing {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex}, this is useful when we just loaded a graph from disk + * copy it into {@link OnHeapGraphIndex} and then start mutating it with minimal overhead of recreating the mutable {@link OnHeapGraphIndex} used in the new GraphIndexBuilder object + * + * @param buildScoreProvider the provider responsible for calculating build scores. + * @param onDiskGraphIndex the on-disk representation of the graph index to be processed and converted. + * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, + * organized by levels and nodes. + * @param beamWidth the width of the beam used during the graph building process. + * @param neighborOverflow the factor determining how many additional neighbors are allowed beyond the configured limit. 
+ * @param alpha the weight factor for balancing score computations. + * @param addHierarchy whether to add hierarchical structures while building the graph. + * @param refineFinalGraph whether to perform a refinement step on the final graph structure. + * @param simdExecutor the ForkJoinPool executor used for SIMD tasks during graph building. + * @param parallelExecutor the ForkJoinPool executor used for general parallelization during graph building. + * + * @throws IOException if an I/O error occurs during the graph loading or conversion process. + */ + public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, OnDiskGraphIndex onDiskGraphIndex, NeighborsScoreCache perLevelNeighborsScoreCache, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) throws IOException { + this.scoreProvider = buildScoreProvider; + this.neighborOverflow = neighborOverflow; + this.dimension = onDiskGraphIndex.getDimension(); + this.alpha = alpha; + this.addHierarchy = addHierarchy; + this.refineFinalGraph = refineFinalGraph; + this.beamWidth = beamWidth; + this.simdExecutor = simdExecutor; + this.parallelExecutor = parallelExecutor; + + this.graph = OnHeapGraphIndex.convertToHeap(onDiskGraphIndex, perLevelNeighborsScoreCache, buildScoreProvider, neighborOverflow, alpha); + + this.searchers = ExplicitThreadLocal.withInitial(() -> { + var gs = new GraphSearcher(graph); + gs.usePruning(false); + return gs; + }); + + // in scratch, we store candidates in reverse order: worse candidates are first + this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1))); + this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1))); + + this.rng = new Random(0); + } + // used by Cassandra when it fine-tunes the PQ codebook public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) { var newBuilder = new GraphIndexBuilder(newProvider, @@ -740,6 +782,60 @@ public synchronized long removeDeletedNodes() { return memorySize; } + /** + * Convenience method to build a new graph from an existing one, with the addition of new nodes. + * This is useful when we want to merge a new set of vectors into an existing graph that is already on disk. + * + * @param onDiskGraphIndex the on-disk representation of the graph index to be processed and converted. + * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, + * @param newVectors a super set RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph + * @param buildScoreProvider the provider responsible for calculating build scores. + * @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start + * @param graphToRavvOrdMap a mapping from the old graph's node ids to the newVectors RAVV node ids + * @param beamWidth the width of the beam used during the graph building process. + * @param overflowRatio the ratio of extra neighbors to allow temporarily when inserting a node. + * @param alpha the weight factor for balancing score computations. + * @param addHierarchy whether to add hierarchical structures while building the graph. + * + * @return the in-memory representation of the graph index. + * @throws IOException if an I/O error occurs during the graph loading or conversion process. 
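+ *
+ * <p>A minimal usage sketch (hypothetical caller; {@code graphPath}, {@code scoreCache}, {@code allVectorsRavv},
+ * {@code buildScoreProvider} and {@code oldVectorCount} are assumed to exist, mirroring the unit tests):
+ * <pre>{@code
+ * try (var readerSupplier = new SimpleMappedReader.Supplier(graphPath);
+ *      var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) {
+ *     // identity mapping: graph node id i corresponds to RAVV ordinal i
+ *     int[] graphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).toArray();
+ *     OnHeapGraphIndex merged = GraphIndexBuilder.buildAndMergeNewNodes(
+ *             onDiskGraph, scoreCache, allVectorsRavv, buildScoreProvider,
+ *             oldVectorCount, graphToRavvOrdMap, 100, 1.2f, 1.2f, false);
+ * }
+ * }</pre>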
+ */ + public static OnHeapGraphIndex buildAndMergeNewNodes(OnDiskGraphIndex onDiskGraphIndex, + NeighborsScoreCache perLevelNeighborsScoreCache, + RandomAccessVectorValues newVectors, + BuildScoreProvider buildScoreProvider, + int startingNodeOffset, + int[] graphToRavvOrdMap, + int beamWidth, + float overflowRatio, + float alpha, + boolean addHierarchy) throws IOException { + + + + try (GraphIndexBuilder builder = new GraphIndexBuilder(buildScoreProvider, + onDiskGraphIndex, + perLevelNeighborsScoreCache, + beamWidth, + overflowRatio, + alpha, + addHierarchy, + true, + PhysicalCoreExecutor.pool(), + ForkJoinPool.commonPool())) { + + var vv = newVectors.threadLocalSupplier(); + + // parallel graph construction from the merge documents Ids + PhysicalCoreExecutor.pool().submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> { + builder.addGraphNode(ord, vv.get().getVector(graphToRavvOrdMap[ord])); + })).join(); + + builder.cleanup(); + return builder.getGraph(); + } + } + private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) { // if either natural or concurrent is empty, skip the merge NodeArray toMerge; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 69f3527f4..e868306f0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -25,14 +25,13 @@ package io.github.jbellis.jvector.graph; import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors; +import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.diversity.DiversityProvider; +import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; -import io.github.jbellis.jvector.util.Accountable; -import io.github.jbellis.jvector.util.Bits; -import io.github.jbellis.jvector.util.DenseIntMap; -import io.github.jbellis.jvector.util.RamUsageEstimator; -import io.github.jbellis.jvector.util.SparseIntMap; -import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet; +import io.github.jbellis.jvector.util.*; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; import org.agrona.collections.IntArrayList; @@ -41,9 +40,12 @@ import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.NoSuchElementException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinTask; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.concurrent.atomic.AtomicReference; @@ -267,7 +269,7 @@ public ConcurrentGraphIndexView getView() { /** * A View that assumes no concurrent modifications are made */ - public GraphIndex.View getFrozenView() { + public FrozenView getFrozenView() { return new FrozenView(); } @@ -421,7 +423,7 @@ public boolean hasNext() { } } - private class FrozenView implements View { + public class FrozenView implements View { @Override public NodesIterator getNeighborsIterator(int level, int node) { 
return OnHeapGraphIndex.this.getNeighborsIterator(level, node); @@ -572,4 +574,68 @@ private void ensureCapacity(int node) { } } } + + /** + * Converts an OnDiskGraphIndex to an OnHeapGraphIndex by copying all nodes, their levels, and neighbors, + * along with other configuration details, from disk-based storage to heap-based storage. + * + * @param diskIndex the disk-based index to be converted + * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, + * organized by levels and nodes. + * @param bsp The build score provider to be used for + * @param overflowRatio usually 1.2f + * @param alpha usually 1.2f + * @return an OnHeapGraphIndex that is equivalent to the provided OnDiskGraphIndex but operates in heap memory + * @throws IOException if an I/O error occurs during the conversion process + */ + public static OnHeapGraphIndex convertToHeap(OnDiskGraphIndex diskIndex, + NeighborsScoreCache perLevelNeighborsScoreCache, + BuildScoreProvider bsp, + float overflowRatio, + float alpha) throws IOException { + + // Create a new OnHeapGraphIndex with the appropriate configuration + List maxDegrees = new ArrayList<>(); + for (int level = 0; level <= diskIndex.getMaxLevel(); level++) { + maxDegrees.add(diskIndex.getDegree(level)); + } + + OnHeapGraphIndex heapIndex = new OnHeapGraphIndex( + maxDegrees, + overflowRatio, // overflow ratio + new VamanaDiversityProvider(bsp, alpha) // diversity provider - can be null for basic usage + ); + + // Copy all nodes and their connections from disk to heap + try (var view = diskIndex.getView()) { + // Copy nodes level by level + for (int level = 0; level <= diskIndex.getMaxLevel(); level++) { + final NodesIterator nodesIterator = diskIndex.getNodes(level); + final Map levelNeighborsScoreCache = perLevelNeighborsScoreCache.getNeighborsScoresInLevel(level); + if (levelNeighborsScoreCache == null) { + throw new IllegalStateException("No neighbors score cache found for level " + level); + } + if (nodesIterator.size() != levelNeighborsScoreCache.size()) { + throw new IllegalStateException("Neighbors score cache size mismatch for level " + level + + ". Expected (currently in index): " + nodesIterator.size() + ", but got (in cache): " + levelNeighborsScoreCache.size()); + } + + while (nodesIterator.hasNext()) { + int nodeId = nodesIterator.next(); + + // Copy neighbors + final NodeArray neighbors = levelNeighborsScoreCache.get(nodeId).copy(); + + // Add the node with its neighbors + heapIndex.addNode(level, nodeId, neighbors); + heapIndex.markComplete(new NodeAtLevel(level, nodeId)); + } + } + + // Set the entry point + heapIndex.updateEntryNode(view.entryNode()); + } + + return heapIndex; + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java new file mode 100644 index 000000000..55fdcf082 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java @@ -0,0 +1,117 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk; + +import io.github.jbellis.jvector.disk.IndexWriter; +import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.*; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * Cache containing pre-computed neighbor scores, organized by levels and nodes. + *
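+ * <p>Serialization round-trip sketch (illustrative; {@code onHeapGraphIndex} and {@code cachePath} are assumed to exist, mirroring the unit tests):
+ * <pre>{@code
+ * var cache = new NeighborsScoreCache(onHeapGraphIndex);
+ * try (SimpleWriter writer = new SimpleWriter(cachePath)) {
+ *     cache.write(writer);
+ * }
+ * try (var readerSupplier = new SimpleMappedReader.Supplier(cachePath)) {
+ *     var cacheFromDisk = new NeighborsScoreCache(readerSupplier.get());
+ * }
+ * }</pre>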

+ * This cache bridges the gap between {@link OnDiskGraphIndex} and {@link OnHeapGraphIndex}:
+ *
+ * When converting from disk to heap representation, this cache avoids expensive score + * recomputation by providing pre-calculated neighbor scores for all graph levels. + * + * @see OnHeapGraphIndex#convertToHeap(OnDiskGraphIndex, NeighborsScoreCache, BuildScoreProvider, float, float) + * + * This is particularly useful when merging new nodes into an existing graph. + * @see GraphIndexBuilder#buildAndMergeNewNodes(OnDiskGraphIndex, NeighborsScoreCache, RandomAccessVectorValues, BuildScoreProvider, int, int[], int, float, float, boolean) + */ +public class NeighborsScoreCache { + private final Map> perLevelNeighborsScoreCache; + + public NeighborsScoreCache(OnHeapGraphIndex graphIndex) throws IOException { + try (OnHeapGraphIndex.FrozenView view = graphIndex.getFrozenView()) { + final Map> perLevelNeighborsScoreCache = new HashMap<>(graphIndex.getMaxLevel() + 1); + for (int level = 0; level <= graphIndex.getMaxLevel(); level++) { + final Map levelNeighborsScores = new HashMap<>(graphIndex.size(level) + 1); + final NodesIterator nodesIterator = graphIndex.getNodes(level); + while (nodesIterator.hasNext()) { + final int nodeId = nodesIterator.nextInt(); + + ConcurrentNeighborMap.NeighborIterator neighborIterator = (ConcurrentNeighborMap.NeighborIterator) view.getNeighborsIterator(level, nodeId); + final NodeArray neighbours = neighborIterator.merge(new NodeArray(neighborIterator.size())); + levelNeighborsScores.put(nodeId, neighbours); + } + + perLevelNeighborsScoreCache.put(level, levelNeighborsScores); + } + + this.perLevelNeighborsScoreCache = perLevelNeighborsScoreCache; + } + } + + public NeighborsScoreCache(RandomAccessReader in) throws IOException { + final int numberOfLevels = in.readInt(); + perLevelNeighborsScoreCache = new HashMap<>(numberOfLevels); + for (int i = 0; i < numberOfLevels; i++) { + final int level = in.readInt(); + final int numberOfNodesInLevel = in.readInt(); + final Map levelNeighborsScores = new HashMap<>(numberOfNodesInLevel); + for (int j = 0; j < numberOfNodesInLevel; j++) { + final int nodeId = in.readInt(); + final int numberOfNeighbors = in.readInt(); + final NodeArray nodeArray = new NodeArray(numberOfNeighbors); + for (int k = 0; k < numberOfNeighbors; k++) { + final int neighborNodeId = in.readInt(); + final float neighborScore = in.readFloat(); + nodeArray.insertSorted(neighborNodeId, neighborScore); + } + levelNeighborsScores.put(nodeId, nodeArray); + } + perLevelNeighborsScoreCache.put(level, levelNeighborsScores); + } + } + + public void write(IndexWriter out) throws IOException { + out.writeInt(perLevelNeighborsScoreCache.size()); // write the number of levels + for (Map.Entry> levelNeighborsScores : perLevelNeighborsScoreCache.entrySet()) { + final int level = levelNeighborsScores.getKey(); + out.writeInt(level); + out.writeInt(levelNeighborsScores.getValue().size()); // write the number of nodes in the level + // Write the neighborhoods for each node in the level + for (Map.Entry nodeArrayEntry : levelNeighborsScores.getValue().entrySet()) { + final int nodeId = nodeArrayEntry.getKey(); + out.writeInt(nodeId); + final NodeArray nodeArray = nodeArrayEntry.getValue(); + out.writeInt(nodeArray.size()); // write the number of neighbors for the node + // Write the nodeArray(neighbors) + for (int i = 0; i < nodeArray.size(); i++) { + out.writeInt(nodeArray.getNode(i)); + out.writeFloat(nodeArray.getScore(i)); + } + } + } + } + + public Map getNeighborsScoresInLevel(int level) { + return perLevelNeighborsScoreCache.get(level); + } + + +} diff --git 
a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index 6239ecc1b..1655599fe 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -19,9 +19,7 @@ import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.disk.ReaderSupplier; -import io.github.jbellis.jvector.graph.GraphIndex; -import io.github.jbellis.jvector.graph.NodesIterator; -import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.*; import io.github.jbellis.jvector.graph.disk.feature.Feature; import io.github.jbellis.jvector.graph.disk.feature.FeatureId; import io.github.jbellis.jvector.graph.disk.feature.FeatureSource; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java index b8ec5fa5f..4656a4fcc 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java @@ -25,6 +25,8 @@ import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import java.util.stream.IntStream; + /** * Encapsulates comparing node distances for GraphIndexBuilder. */ @@ -83,8 +85,17 @@ public interface BuildScoreProvider { /** * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction. + * + * Helper method for the special case that mapping between graph node IDs and ravv ordinals is the identity function. */ static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) { + return randomAccessScoreProvider(ravv, IntStream.range(0, ravv.size()).toArray(), similarityFunction); + } + + /** + * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction. + */ + static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) { // We need two sources of vectors in order to perform diversity check comparisons without // colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared. 
var vectors = ravv.threadLocalSupplier(); @@ -113,22 +124,22 @@ public VectorFloat approximateCentroid() { @Override public SearchScoreProvider searchProviderFor(VectorFloat vector) { var vc = vectorsCopy.get(); - return DefaultSearchScoreProvider.exact(vector, similarityFunction, vc); + return DefaultSearchScoreProvider.exact(vector, graphToRavvOrdMap, similarityFunction, vc); } @Override public SearchScoreProvider searchProviderFor(int node1) { RandomAccessVectorValues randomAccessVectorValues = vectors.get(); - var v = randomAccessVectorValues.getVector(node1); + var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]); return searchProviderFor(v); } @Override public SearchScoreProvider diversityProviderFor(int node1) { RandomAccessVectorValues randomAccessVectorValues = vectors.get(); - var v = randomAccessVectorValues.getVector(node1); + var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]); var vc = vectorsCopy.get(); - return DefaultSearchScoreProvider.exact(v, similarityFunction, vc); + return DefaultSearchScoreProvider.exact(v, graphToRavvOrdMap, similarityFunction, vc); } }; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java index 0754b39d7..de46762b2 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java @@ -78,4 +78,20 @@ public float similarityTo(int node2) { }; return new DefaultSearchScoreProvider(sf); } + + /** + * A SearchScoreProvider for a single-pass search based on exact similarity. + * Generally only suitable when your RandomAccessVectorValues is entirely in-memory, + * e.g. during construction. + */ + public static DefaultSearchScoreProvider exact(VectorFloat v, int[] graphToRavvOrdMap ,VectorSimilarityFunction vsf, RandomAccessVectorValues ravv) { + // don't use ESF.reranker, we need thread safety here + var sf = new ScoreFunction.ExactScoreFunction() { + @Override + public float similarityTo(int node2) { + return vsf.compare(v, ravv.getVector(graphToRavvOrdMap[node2])); + } + }; + return new DefaultSearchScoreProvider(sf); + } } \ No newline at end of file diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index f66a2c6e4..73e59b20f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -77,6 +77,8 @@ public static PQVectors load(RandomAccessReader in, long offset) throws IOExcept * Build a PQVectors instance from the given RandomAccessVectorValues. The vectors are encoded in parallel * and split into chunks to avoid exceeding the maximum array size. * + * This is a helper method for the special case where the ordinals mapping in the graph and the RAVV/PQVectors are the same. 
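+ * <p>It simply delegates to the {@code ordinalsMapping} overload below with an identity mapping:
+ * <pre>{@code
+ * encodeAndBuild(pq, vectorCount, IntStream.range(0, vectorCount).toArray(), ravv, simdExecutor);
+ * }</pre>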
+ * * @param pq the ProductQuantization to use * @param vectorCount the number of vectors to encode * @param ravv the RandomAccessVectorValues to encode @@ -84,6 +86,21 @@ public static PQVectors load(RandomAccessReader in, long offset) throws IOExcept * @return the PQVectors instance */ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vectorCount, RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) { + return encodeAndBuild(pq, vectorCount, IntStream.range(0, vectorCount).toArray(), ravv, simdExecutor); + } + + /** + * Build a PQVectors instance from the given RandomAccessVectorValues. The vectors are encoded in parallel + * and split into chunks to avoid exceeding the maximum array size. + * + * @param pq the ProductQuantization to use + * @param vectorCount the number of vectors to encode + * @param ravv the RandomAccessVectorValues to encode + * @param simdExecutor the ForkJoinPool to use for SIMD operations + * @param ordinalsMapping the graph ordinals to RAVV mapping + * @return the PQVectors instance + */ + public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vectorCount, int[] ordinalsMapping, RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) { int compressedDimension = pq.compressedVectorSize(); PQLayout layout = new PQLayout(vectorCount,compressedDimension); final ByteSequence[] chunks = new ByteSequence[layout.totalChunks]; @@ -98,13 +115,13 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect // The changes are concurrent, but because they are coordinated and do not overlap, we can use parallel streams // and then we are guaranteed safe publication because we join the thread after completion. var ravvCopy = ravv.threadLocalSupplier(); - simdExecutor.submit(() -> IntStream.range(0, ravv.size()) + simdExecutor.submit(() -> IntStream.range(0, ordinalsMapping.length) .parallel() .forEach(ordinal -> { // Retrieve the slice and mutate it. var localRavv = ravvCopy.get(); var slice = PQVectors.get(chunks, ordinal, layout.fullChunkVectors, pq.getSubspaceCount()); - var vector = localRavv.getVector(ordinal); + var vector = localRavv.getVector(ordinalsMapping[ordinal]); if (vector != null) pq.encodeTo(vector, slice); else diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java new file mode 100644 index 000000000..b4efb543d --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java @@ -0,0 +1,264 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.github.jbellis.jvector.graph; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.TestUtil; +import io.github.jbellis.jvector.disk.SimpleMappedReader; +import io.github.jbellis.jvector.disk.SimpleWriter; +import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.apache.logging.log4j.Logger; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class OnHeapGraphIndexTest extends RandomizedTest { + private final static Logger log = org.apache.logging.log4j.LogManager.getLogger(OnHeapGraphIndexTest.class); + private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + private static final int NUM_BASE_VECTORS = 100; + private static final int NUM_NEW_VECTORS = 100; + private static final int NUM_ALL_VECTORS = NUM_BASE_VECTORS + NUM_NEW_VECTORS; + private static final int DIMENSION = 16; + private static final int M = 8; + private static final int BEAM_WIDTH = 100; + private static final float ALPHA = 1.2f; + private static final float NEIGHBOR_OVERFLOW = 1.2f; + private static final boolean ADD_HIERARCHY = false; + private static final int TOP_K = 10; + + private Path testDirectory; + + private ArrayList> baseVectors; + private ArrayList> newVectors; + private ArrayList> allVectors; + private RandomAccessVectorValues baseVectorsRavv; + private RandomAccessVectorValues newVectorsRavv; + private RandomAccessVectorValues allVectorsRavv; + private VectorFloat queryVector; + private int[] groundTruthBaseVectors; + private int[] groundTruthAllVectors; + private BuildScoreProvider baseBuildScoreProvider; + private BuildScoreProvider newBuildScoreProvider; + private BuildScoreProvider allBuildScoreProvider; + private OnHeapGraphIndex baseGraphIndex; + private OnHeapGraphIndex newGraphIndex; + private OnHeapGraphIndex allGraphIndex; + + @Before + public void setup() throws IOException { + testDirectory = Files.createTempDirectory(this.getClass().getSimpleName()); + baseVectors = new ArrayList<>(NUM_BASE_VECTORS); + newVectors = new ArrayList<>(NUM_NEW_VECTORS); + allVectors = new ArrayList<>(NUM_ALL_VECTORS); + for (int i = 0; i < NUM_BASE_VECTORS; i++) { + VectorFloat vector = createRandomVector(DIMENSION); + baseVectors.add(vector); + allVectors.add(vector); + } + for (int i = 0; i < NUM_NEW_VECTORS; i++) { + VectorFloat vector = createRandomVector(DIMENSION); + newVectors.add(vector); + allVectors.add(vector); + } + + // wrap the raw vectors in a 
RandomAccessVectorValues + baseVectorsRavv = new ListRandomAccessVectorValues(baseVectors, DIMENSION); + newVectorsRavv = new ListRandomAccessVectorValues(newVectors, DIMENSION); + allVectorsRavv = new ListRandomAccessVectorValues(allVectors, DIMENSION); + + queryVector = createRandomVector(DIMENSION); + groundTruthBaseVectors = getGroundTruth(baseVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN); + groundTruthAllVectors = getGroundTruth(allVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN); + + // score provider using the raw, in-memory vectors + baseBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); + newBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(newVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); + allBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(allVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); + var baseGraphIndexBuilder = new GraphIndexBuilder(baseBuildScoreProvider, + baseVectorsRavv.dimension(), + M, // graph degree + BEAM_WIDTH, // construction search depth + NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor + ALPHA, // relax neighbor diversity requirement by this factor + ADD_HIERARCHY); // add the hierarchy + var allGraphIndexBuilder = new GraphIndexBuilder(allBuildScoreProvider, + allVectorsRavv.dimension(), + M, // graph degree + BEAM_WIDTH, // construction search depth + NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor + ALPHA, // relax neighbor diversity requirement by this factor + ADD_HIERARCHY); // add the hierarchy + + baseGraphIndex = baseGraphIndexBuilder.build(baseVectorsRavv); + allGraphIndex = allGraphIndexBuilder.build(allVectorsRavv); + } + + @After + public void tearDown() { + TestUtil.deleteQuietly(testDirectory); + } + + + /** + * Create an {@link OnHeapGraphIndex} persist it as a {@link OnDiskGraphIndex} and reconstruct back to a mutable {@link OnHeapGraphIndex} + * Make sure that both graphs are equivalent + * @throws IOException + */ + @Test + public void testReconstructionOfOnHeapGraphIndex() throws IOException { + var graphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); + var neighborsScoreCacheOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + NeighborsScoreCache.class.getSimpleName()); + log.info("Writing graph to {}", graphOutputPath); + TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, graphOutputPath); + + log.info("Writing neighbors score cache to {}", neighborsScoreCacheOutputPath); + final NeighborsScoreCache neighborsScoreCache = new NeighborsScoreCache(baseGraphIndex); + try (SimpleWriter writer = new SimpleWriter(neighborsScoreCacheOutputPath.toAbsolutePath())) { + neighborsScoreCache.write(writer); + } + + log.info("Reading neighbors score cache from {}", neighborsScoreCacheOutputPath); + final NeighborsScoreCache neighborsScoreCacheRead; + try (var readerSupplier = new SimpleMappedReader.Supplier(neighborsScoreCacheOutputPath.toAbsolutePath())) { + neighborsScoreCacheRead = new NeighborsScoreCache(readerSupplier.get()); + } + + try (var readerSupplier = new SimpleMappedReader.Supplier(graphOutputPath.toAbsolutePath()); + var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) { + TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph); + try (var onDiskView = onDiskGraph.getView()) { + validateVectors(onDiskView, baseVectorsRavv); + } 
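+ // Rebuild a mutable on-heap graph from the on-disk index, reusing the deserialized neighbor-score cache instead of recomputing scores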
+ + OnHeapGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex.convertToHeap(onDiskGraph, neighborsScoreCacheRead, baseBuildScoreProvider, NEIGHBOR_OVERFLOW, ALPHA); + TestUtil.assertGraphEquals(baseGraphIndex, reconstructedOnHeapGraphIndex); + TestUtil.assertGraphEquals(onDiskGraph, reconstructedOnHeapGraphIndex); + + } + } + + /** + * Create {@link OnDiskGraphIndex} then append to it via {@link GraphIndexBuilder#buildAndMergeNewNodes} + * Verify that the resulting OnHeapGraphIndex is equivalent to the graph that would have been alternatively generated by bulk index into a new {@link OnDiskGraphIndex} + */ + @Test + public void testIncrementalInsertionFromOnDiskIndex() throws IOException { + var outputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); + log.info("Writing graph to {}", outputPath); + final NeighborsScoreCache neighborsScoreCache = new NeighborsScoreCache(baseGraphIndex); + TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, outputPath); + try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath.toAbsolutePath()); + var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) { + TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph); + // We will create a trivial 1:1 mapping between the new graph and the ravv + final int[] graphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).toArray(); + OnHeapGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(onDiskGraph, neighborsScoreCache, allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); + + // Verify that the recall is similar + float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateRecall(reconstructedAllNodeOnHeapGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); + float recallFromAllGraphIndex = calculateRecall(allGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); + Assert.assertEquals(recallFromReconstructedAllNodeOnHeapGraphIndex, recallFromAllGraphIndex, 0.01f); + } + } + + public static void validateVectors(OnDiskGraphIndex.View view, RandomAccessVectorValues ravv) { + for (int i = 0; i < view.size(); i++) { + assertEquals("Incorrect vector at " + i, ravv.getVector(i), view.getVector(i)); + } + } + + private VectorFloat createRandomVector(int dimension) { + VectorFloat vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension); + for (int i = 0; i < dimension; i++) { + vector.set(i, (float) Math.random()); + } + return vector; + } + + /** + * Get the ground truth for a query vector + * @param ravv the vectors to search + * @param queryVector the query vector + * @param topK the number of results to return + * @param similarityFunction the similarity function to use + + * @return the ground truth + */ + private static int[] getGroundTruth(RandomAccessVectorValues ravv, VectorFloat queryVector, int topK, VectorSimilarityFunction similarityFunction) { + var exactResults = new ArrayList(); + for (int i = 0; i < ravv.size(); i++) { + float similarityScore = similarityFunction.compare(queryVector, ravv.getVector(i)); + exactResults.add(new SearchResult.NodeScore(i, similarityScore)); + } + exactResults.sort((a, b) -> Float.compare(b.score, a.score)); + return exactResults.stream().limit(topK).mapToInt(nodeScore -> nodeScore.node).toArray(); + } + + private static float calculateRecall(OnHeapGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, VectorFloat 
queryVector, int[] groundTruth, int k) throws IOException { + try (GraphSearcher graphSearcher = new GraphSearcher(graphIndex)) { + SearchScoreProvider ssp = buildScoreProvider.searchProviderFor(queryVector); + var searchResults = graphSearcher.search(ssp, k, Bits.ALL); + var predicted = Arrays.stream(searchResults.getNodes()).mapToInt(nodeScore -> nodeScore.node).boxed().collect(Collectors.toSet()); + return calculateRecall(predicted, groundTruth, k); + } + } + /** + * Calculate recall@k for a set of predicted results. + * @param predicted the predicted results + * @param groundTruth the ground truth + * @param k the number of results to consider + * @return the fraction of the top-k ground truth results found in the predicted set + */ + private static float calculateRecall(Set<Integer> predicted, int[] groundTruth, int k) { + int hits = 0; + int actualK = Math.min(k, Math.min(predicted.size(), groundTruth.length)); + + // count how many of the top-k ground truth results were actually retrieved + for (int j = 0; j < actualK; j++) { + if (predicted.contains(groundTruth[j])) { + hits++; + } + } + + return ((float) hits) / (float) actualK; + } +} diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java new file mode 100644 index 000000000..4942b8efb --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java @@ -0,0 +1,72 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.similarity; + +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class BuildScoreProviderTest { + private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); + + /** + * Test that the ordinal mapping is correctly applied when creating search and diversity score providers.
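+ * <p>The mapping direction is graph node id to RAVV ordinal: with {@code graphToRavvOrdMap = {2, 0, 1}}, graph node 0 is scored against {@code ravv.getVector(2)}.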
+ */ + @Test + public void testOrdinalMapping() { + final VectorSimilarityFunction vsf = VectorSimilarityFunction.DOT_PRODUCT; + + // Create test vectors + final List> vectors = new ArrayList<>(); + vectors.add(vts.createFloatVector(new float[]{1.0f, 0.0f})); + vectors.add(vts.createFloatVector(new float[]{0.0f, 1.0f})); + vectors.add(vts.createFloatVector(new float[]{-1.0f, 0.0f})); + var ravv = new ListRandomAccessVectorValues(vectors, 2); + + // Create non-identity mapping: graph node 0 -> ravv ordinal 2, graph node 1 -> ravv ordinal 0, graph node 2 -> ravv ordinal 1 + int[] graphToRavvOrdMap = {2, 0, 1}; + + var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, graphToRavvOrdMap, vsf); + + // Test that searchProviderFor(graphNode) uses the correct RAVV ordinal + var ssp0 = bsp.searchProviderFor(0); // should use ravv ordinal 2 (vector [-1, 0]) + var ssp1 = bsp.searchProviderFor(1); // should use ravv ordinal 0 (vector [1, 0]) + var ssp2 = bsp.searchProviderFor(2); // should use ravv ordinal 1 (vector [0, 1]) + + // Verify by computing similarity between graph nodes + // Graph node 0 (vector 2:[-1, 0]) vs graph node 1 (vector 0:[1, 0]) + assertEquals(vsf.compare(vectors.get(2), vectors.get(0)), ssp0.exactScoreFunction().similarityTo(1), 1e-6f); + + // Graph node 1 (vector 0:[1, 0]) vs graph node 0 (vector 2:[-1, 0]) + assertEquals(vsf.compare(vectors.get(0), vectors.get(2)), ssp1.exactScoreFunction().similarityTo(0), 1e-6f); + + // Graph node 2 (vector 1:[0, 1]) vs graph node 1 (vector 0:[1, 0]) + assertEquals(vsf.compare(vectors.get(1), vectors.get(0)), ssp2.exactScoreFunction().similarityTo(1), 1e-6f); + + // Test diversityProviderFor uses same mapping, Graph node 0 (vector 2:[-1, 0]) vs graph node 1 (vector 0:[1, 0]) + var dsp0 = bsp.diversityProviderFor(0); + assertEquals(vsf.compare(vectors.get(2), vectors.get(0)), dsp0.exactScoreFunction().similarityTo(1), 1e-6f); + } +} \ No newline at end of file