Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ public NeighborWithShortEdges(Neighbors neighbors, double shortEdges) {
}
}

private static class NeighborIterator implements NodesIterator {
public static class NeighborIterator implements NodesIterator {
private final NodeArray neighbors;
private int i;

Expand All @@ -374,5 +374,9 @@ public boolean hasNext() {
public int nextInt() {
    // Return the neighbor at the current cursor position, then advance.
    int node = neighbors.getNode(i);
    i += 1;
    return node;
}

public NodeArray merge(NodeArray other) {
    // Combines this iterator's backing neighbor list with `other` via the
    // static NodeArray.merge helper; presumably yields a new score-sorted
    // NodeArray containing both sets -- confirm against NodeArray.merge's contract.
    // Note: operates on the backing array directly and does not advance
    // this iterator's cursor.
    return NodeArray.merge(neighbors, other);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.GraphIndex.NodeAtLevel;
import io.github.jbellis.jvector.graph.SearchResult.NodeScore;
import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
Expand All @@ -38,10 +40,7 @@
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -297,8 +296,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
boolean addHierarchy,
boolean refineFinalGraph,
ForkJoinPool simdExecutor,
ForkJoinPool parallelExecutor)
{
ForkJoinPool parallelExecutor) {
if (maxDegrees.stream().anyMatch(i -> i <= 0)) {
throw new IllegalArgumentException("layer degrees must be positive");
}
Expand Down Expand Up @@ -339,6 +337,50 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
this.rng = new Random(0);
}

/**
* Creates this builder from an existing {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex}. This is useful when a graph has just been loaded from disk:
* it is copied into an {@link OnHeapGraphIndex} and can then be mutated, with minimal overhead compared to recreating the mutable {@link OnHeapGraphIndex} used by a new GraphIndexBuilder from scratch.
*
* @param buildScoreProvider the provider responsible for calculating build scores.
* @param onDiskGraphIndex the on-disk representation of the graph index to be processed and converted.
* @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores,
* organized by levels and nodes.
* @param beamWidth the width of the beam used during the graph building process.
* @param neighborOverflow the factor determining how many additional neighbors are allowed beyond the configured limit.
* @param alpha the weight factor for balancing score computations.
* @param addHierarchy whether to add hierarchical structures while building the graph.
* @param refineFinalGraph whether to perform a refinement step on the final graph structure.
* @param simdExecutor the ForkJoinPool executor used for SIMD tasks during graph building.
* @param parallelExecutor the ForkJoinPool executor used for general parallelization during graph building.
*
* @throws IOException if an I/O error occurs during the graph loading or conversion process.
*/
public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, OnDiskGraphIndex onDiskGraphIndex, NeighborsScoreCache perLevelNeighborsScoreCache, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) throws IOException {
    // Plain configuration fields.
    this.scoreProvider = buildScoreProvider;
    this.dimension = onDiskGraphIndex.getDimension();
    this.beamWidth = beamWidth;
    this.neighborOverflow = neighborOverflow;
    this.alpha = alpha;
    this.addHierarchy = addHierarchy;
    this.refineFinalGraph = refineFinalGraph;
    this.simdExecutor = simdExecutor;
    this.parallelExecutor = parallelExecutor;

    // Materialize the on-disk graph as a mutable on-heap graph, reusing the
    // pre-computed neighbor scores instead of recomputing them.
    this.graph = OnHeapGraphIndex.convertToHeap(onDiskGraphIndex, perLevelNeighborsScoreCache, buildScoreProvider, neighborOverflow, alpha);

    // Per-thread searchers over the freshly converted graph; pruning disabled.
    this.searchers = ExplicitThreadLocal.withInitial(() -> {
        GraphSearcher searcher = new GraphSearcher(graph);
        searcher.usePruning(false);
        return searcher;
    });

    // Per-thread scratch space sized for a full beam or a full (over-sized) neighbor
    // list; candidates are stored in reverse order (worse candidates first).
    this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));
    this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));

    // Fixed seed for reproducible level assignment.
    this.rng = new Random(0);
}

// used by Cassandra when it fine-tunes the PQ codebook
public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) {
var newBuilder = new GraphIndexBuilder(newProvider,
Expand Down Expand Up @@ -740,6 +782,60 @@ public synchronized long removeDeletedNodes() {
return memorySize;
}

/**
* Convenience method to build a new graph from an existing one, with the addition of new nodes.
* This is useful when we want to merge a new set of vectors into an existing graph that is already on disk.
*
* @param onDiskGraphIndex the on-disk representation of the graph index to be processed and converted.
* @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores,
* @param newVectors a super set RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph
* @param buildScoreProvider the provider responsible for calculating build scores.
* @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start
* @param graphToRavvOrdMap a mapping from the old graph's node ids to the newVectors RAVV node ids
* @param beamWidth the width of the beam used during the graph building process.
* @param overflowRatio the ratio of extra neighbors to allow temporarily when inserting a node.
* @param alpha the weight factor for balancing score computations.
* @param addHierarchy whether to add hierarchical structures while building the graph.
*
* @return the in-memory representation of the graph index.
* @throws IOException if an I/O error occurs during the graph loading or conversion process.
*/
public static OnHeapGraphIndex buildAndMergeNewNodes(OnDiskGraphIndex onDiskGraphIndex,
                                                     NeighborsScoreCache perLevelNeighborsScoreCache,
                                                     RandomAccessVectorValues newVectors,
                                                     BuildScoreProvider buildScoreProvider,
                                                     int startingNodeOffset,
                                                     int[] graphToRavvOrdMap,
                                                     int beamWidth,
                                                     float overflowRatio,
                                                     float alpha,
                                                     boolean addHierarchy) throws IOException {
    // Rehydrate the on-disk graph into a mutable builder, reusing the cached neighbor scores.
    try (GraphIndexBuilder builder = new GraphIndexBuilder(buildScoreProvider,
                                                           onDiskGraphIndex,
                                                           perLevelNeighborsScoreCache,
                                                           beamWidth,
                                                           overflowRatio,
                                                           alpha,
                                                           addHierarchy,
                                                           true, // refineFinalGraph
                                                           PhysicalCoreExecutor.pool(),
                                                           ForkJoinPool.commonPool())) {
        var vectorSupplier = newVectors.threadLocalSupplier();

        // Insert every new ordinal in [startingNodeOffset, newVectors.size()) in parallel,
        // mapping each graph ordinal to its position in the RAVV.
        PhysicalCoreExecutor.pool().submit(() ->
                IntStream.range(startingNodeOffset, newVectors.size())
                         .parallel()
                         .forEach(ord -> builder.addGraphNode(ord, vectorSupplier.get().getVector(graphToRavvOrdMap[ord])))
        ).join();

        builder.cleanup();
        return builder.getGraph();
    }
}

private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) {
// if either natural or concurrent is empty, skip the merge
NodeArray toMerge;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors;
import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.diversity.DiversityProvider;
import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.DenseIntMap;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.util.SparseIntMap;
import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet;
import io.github.jbellis.jvector.util.*;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import org.agrona.collections.IntArrayList;

Expand All @@ -41,9 +40,12 @@
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -267,7 +269,7 @@ public ConcurrentGraphIndexView getView() {
/**
* A View that assumes no concurrent modifications are made
*/
public GraphIndex.View getFrozenView() {
public FrozenView getFrozenView() {
return new FrozenView();
}

Expand Down Expand Up @@ -421,7 +423,7 @@ public boolean hasNext() {
}
}

private class FrozenView implements View {
public class FrozenView implements View {
@Override
public NodesIterator getNeighborsIterator(int level, int node) {
return OnHeapGraphIndex.this.getNeighborsIterator(level, node);
Expand Down Expand Up @@ -572,4 +574,68 @@ private void ensureCapacity(int node) {
}
}
}

/**
 * Converts an OnDiskGraphIndex to an OnHeapGraphIndex by copying all nodes, their levels, and
 * their neighbors (with pre-computed scores) from disk-based storage to heap-based storage.
 *
 * @param diskIndex the disk-based index to be converted
 * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores,
 *                                    organized by level and then by node id; it must cover
 *                                    exactly the nodes present in {@code diskIndex}
 * @param bsp the build score provider used to construct the new graph's diversity provider
 * @param overflowRatio usually 1.2f; how many extra neighbors to tolerate before pruning
 * @param alpha usually 1.2f; diversity weight factor
 * @return an OnHeapGraphIndex equivalent to the provided OnDiskGraphIndex but operating in heap memory
 * @throws IOException if an I/O error occurs during the conversion process
 * @throws IllegalStateException if the score cache does not match the contents of the index
 */
public static OnHeapGraphIndex convertToHeap(OnDiskGraphIndex diskIndex,
                                             NeighborsScoreCache perLevelNeighborsScoreCache,
                                             BuildScoreProvider bsp,
                                             float overflowRatio,
                                             float alpha) throws IOException {

    // The heap index needs the per-level max degrees up front.
    List<Integer> maxDegrees = new ArrayList<>();
    for (int level = 0; level <= diskIndex.getMaxLevel(); level++) {
        maxDegrees.add(diskIndex.getDegree(level));
    }

    OnHeapGraphIndex heapIndex = new OnHeapGraphIndex(
            maxDegrees,
            overflowRatio,
            new VamanaDiversityProvider(bsp, alpha));

    // Copy all nodes and their connections from disk to heap, level by level.
    try (var view = diskIndex.getView()) {
        for (int level = 0; level <= diskIndex.getMaxLevel(); level++) {
            final NodesIterator nodesIterator = diskIndex.getNodes(level);
            final Map<Integer, NodeArray> levelNeighborsScoreCache = perLevelNeighborsScoreCache.getNeighborsScoresInLevel(level);
            if (levelNeighborsScoreCache == null) {
                throw new IllegalStateException("No neighbors score cache found for level " + level);
            }
            if (nodesIterator.size() != levelNeighborsScoreCache.size()) {
                throw new IllegalStateException("Neighbors score cache size mismatch for level " + level +
                        ". Expected (currently in index): " + nodesIterator.size() + ", but got (in cache): " + levelNeighborsScoreCache.size());
            }

            while (nodesIterator.hasNext()) {
                int nodeId = nodesIterator.next();

                // Equal sizes do not guarantee equal key sets: a node present in the
                // index but missing from the cache would otherwise surface as a bare
                // NullPointerException on copy(). Fail with context instead.
                final NodeArray cachedNeighbors = levelNeighborsScoreCache.get(nodeId);
                if (cachedNeighbors == null) {
                    throw new IllegalStateException("No cached neighbor scores for node " + nodeId + " at level " + level);
                }

                // Insert a private copy so later mutations of the heap graph
                // cannot corrupt the shared score cache.
                heapIndex.addNode(level, nodeId, cachedNeighbors.copy());
                heapIndex.markComplete(new NodeAtLevel(level, nodeId));
            }
        }

        // Preserve the original entry point.
        heapIndex.updateEntryNode(view.entryNode());
    }

    return heapIndex;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.graph.disk;

import io.github.jbellis.jvector.disk.IndexWriter;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.*;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
* Cache containing pre-computed neighbor scores, organized by levels and nodes.
* <p>
* This cache bridges the gap between {@link OnDiskGraphIndex} and {@link OnHeapGraphIndex}:
* <ul>
* <li>{@link OnDiskGraphIndex} stores only neighbor IDs (not scores) for space efficiency</li>
* <li>{@link OnHeapGraphIndex} requires neighbor scores for pruning operations</li>
* </ul>
* <p>
* When converting from disk to heap representation, this cache avoids expensive score
* recomputation by providing pre-calculated neighbor scores for all graph levels.
*
* @see OnHeapGraphIndex#convertToHeap(OnDiskGraphIndex, NeighborsScoreCache, BuildScoreProvider, float, float)
*
* This is particularly useful when merging new nodes into an existing graph.
* @see GraphIndexBuilder#buildAndMergeNewNodes(OnDiskGraphIndex, NeighborsScoreCache, RandomAccessVectorValues, BuildScoreProvider, int, int[], int, float, float, boolean)
*/
public class NeighborsScoreCache {
    // level -> (node id -> neighbors with scores)
    private final Map<Integer, Map<Integer, NodeArray>> perLevelNeighborsScoreCache;

    /**
     * Snapshots the neighbor lists (ids and scores) of every node in the given heap index.
     *
     * @param graphIndex the heap index whose neighborhoods should be cached
     * @throws IOException if closing the frozen view fails
     */
    public NeighborsScoreCache(OnHeapGraphIndex graphIndex) throws IOException {
        try (OnHeapGraphIndex.FrozenView view = graphIndex.getFrozenView()) {
            final Map<Integer, Map<Integer, NodeArray>> perLevelNeighborsScoreCache = new HashMap<>(graphIndex.getMaxLevel() + 1);
            for (int level = 0; level <= graphIndex.getMaxLevel(); level++) {
                final Map<Integer, NodeArray> levelNeighborsScores = new HashMap<>(graphIndex.size(level) + 1);
                final NodesIterator nodesIterator = graphIndex.getNodes(level);
                while (nodesIterator.hasNext()) {
                    final int nodeId = nodesIterator.nextInt();

                    // merge() against an empty NodeArray materializes the iterator's
                    // backing neighbor list (ids + scores) as a standalone NodeArray.
                    ConcurrentNeighborMap.NeighborIterator neighborIterator = (ConcurrentNeighborMap.NeighborIterator) view.getNeighborsIterator(level, nodeId);
                    final NodeArray neighbours = neighborIterator.merge(new NodeArray(neighborIterator.size()));
                    levelNeighborsScores.put(nodeId, neighbours);
                }

                perLevelNeighborsScoreCache.put(level, levelNeighborsScores);
            }

            this.perLevelNeighborsScoreCache = perLevelNeighborsScoreCache;
        }
    }

    /**
     * Deserializes a cache previously written by {@link #write(IndexWriter)}.
     * The format is: levelCount, then per level: level id, node count, then per
     * node: node id, neighbor count, then per neighbor: neighbor id, score.
     *
     * @param in the reader positioned at the start of a serialized cache
     * @throws IOException if reading fails
     */
    public NeighborsScoreCache(RandomAccessReader in) throws IOException {
        final int numberOfLevels = in.readInt();
        perLevelNeighborsScoreCache = new HashMap<>(numberOfLevels);
        for (int i = 0; i < numberOfLevels; i++) {
            final int level = in.readInt();
            final int numberOfNodesInLevel = in.readInt();
            final Map<Integer, NodeArray> levelNeighborsScores = new HashMap<>(numberOfNodesInLevel);
            for (int j = 0; j < numberOfNodesInLevel; j++) {
                final int nodeId = in.readInt();
                final int numberOfNeighbors = in.readInt();
                final NodeArray nodeArray = new NodeArray(numberOfNeighbors);
                for (int k = 0; k < numberOfNeighbors; k++) {
                    final int neighborNodeId = in.readInt();
                    final float neighborScore = in.readFloat();
                    nodeArray.insertSorted(neighborNodeId, neighborScore);
                }
                levelNeighborsScores.put(nodeId, nodeArray);
            }
            perLevelNeighborsScoreCache.put(level, levelNeighborsScores);
        }
    }

    /**
     * Serializes this cache in the format understood by
     * {@link #NeighborsScoreCache(RandomAccessReader)}.
     *
     * @param out the writer to serialize into
     * @throws IOException if writing fails
     */
    public void write(IndexWriter out) throws IOException {
        out.writeInt(perLevelNeighborsScoreCache.size()); // write the number of levels
        for (Map.Entry<Integer, Map<Integer, NodeArray>> levelNeighborsScores : perLevelNeighborsScoreCache.entrySet()) {
            final int level = levelNeighborsScores.getKey();
            out.writeInt(level);
            out.writeInt(levelNeighborsScores.getValue().size()); // write the number of nodes in the level
            // Write the neighborhoods for each node in the level
            for (Map.Entry<Integer, NodeArray> nodeArrayEntry : levelNeighborsScores.getValue().entrySet()) {
                final int nodeId = nodeArrayEntry.getKey();
                out.writeInt(nodeId);
                final NodeArray nodeArray = nodeArrayEntry.getValue();
                out.writeInt(nodeArray.size()); // write the number of neighbors for the node
                // Write the nodeArray(neighbors): (id, score) pairs
                for (int i = 0; i < nodeArray.size(); i++) {
                    out.writeInt(nodeArray.getNode(i));
                    out.writeFloat(nodeArray.getScore(i));
                }
            }
        }
    }

    /**
     * Returns a read-only view of the node-id to neighbors map for the given level,
     * or null if the level is not present. The view is unmodifiable so callers
     * cannot corrupt the cache's internal state.
     *
     * @param level the graph level to look up
     * @return an unmodifiable map of node id to scored neighbors, or null
     */
    public Map<Integer, NodeArray> getNeighborsScoresInLevel(int level) {
        final Map<Integer, NodeArray> levelScores = perLevelNeighborsScoreCache.get(level);
        return levelScores == null ? null : Collections.unmodifiableMap(levelScores);
    }
}
Loading
Loading