io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java

@@ -16,16 +16,13 @@
 
 package io.github.jbellis.jvector.graph;
 
-import io.github.jbellis.jvector.annotations.VisibleForTesting;
 import io.github.jbellis.jvector.graph.diversity.DiversityProvider;
 import io.github.jbellis.jvector.util.BitSet;
 import io.github.jbellis.jvector.util.Bits;
 import io.github.jbellis.jvector.util.DenseIntMap;
 import io.github.jbellis.jvector.util.FixedBitSet;
 import io.github.jbellis.jvector.util.IntMap;
 
-import static java.lang.Math.min;
-
 /**
  * Encapsulates operations on a graph's neighbors.
  */
@@ -179,17 +176,13 @@ public static class Neighbors extends NodeArray {
         /** the node id whose neighbors we are storing */
         private final int nodeId;
 
-        /** entries in `nodes` before this index are diverse and don't need to be checked again */
-        private int diverseBefore;
-
         /**
          * uses the node and score references directly from `nodeArray`, without copying
          * `nodeArray` is assumed to have had diversity enforced already
         */
         private Neighbors(int nodeId, NodeArray nodeArray) {
             super(nodeArray);
             this.nodeId = nodeId;
-            this.diverseBefore = size();
         }
 
         public NodesIterator iterator() {
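
Context for the `diverseBefore` removal: the field was a watermark recording how much of the score-sorted neighbor list had already passed diversity checks, so later prunes could skip that prefix. Without it, each prune re-examines the whole list. A minimal self-contained sketch of the watermark idea follows; the names and the `isDiverseAmong` predicate are illustrative, not jvector's actual API.

    // Illustration of incremental pruning with a watermark: entries before
    // `diverseBefore` were already accepted as diverse, so only the suffix
    // needs to be re-examined on the next prune.
    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.BiPredicate;

    class WatermarkPruneSketch {
        static <T> int prune(List<T> sorted, int diverseBefore, int maxDegree,
                             BiPredicate<T, List<T>> isDiverseAmong) {
            List<T> kept = new ArrayList<>(sorted.subList(0, diverseBefore)); // trusted prefix
            for (int i = diverseBefore; i < sorted.size() && kept.size() < maxDegree; i++) {
                T candidate = sorted.get(i);
                if (isDiverseAmong.test(candidate, kept)) {
                    kept.add(candidate);
                }
            }
            sorted.retainAll(kept);
            return sorted.size(); // the new watermark: everything kept is now diverse
        }
    }
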
@@ -217,8 +210,7 @@ private NeighborWithShortEdges enforceDegree(ConcurrentNeighborMap map) {
                 return new NeighborWithShortEdges(this, Double.NaN);
             }
             var next = copy();
-            double shortEdges = retainDiverseInternal(next, diverseBefore, map);
-            next.diverseBefore = next.size();
+            double shortEdges = retainDiverseInternal(next, map);
             return new NeighborWithShortEdges(next, shortEdges);
         }
 
@@ -234,7 +226,7 @@ private Neighbors replaceDeletedNeighbors(Bits deletedNodes, NodeArray candidate
 
             // merge the remaining neighbors with the candidates and keep the diverse results
             NodeArray merged = NodeArray.merge(liveNeighbors, candidates);
-            retainDiverseInternal(merged, 0, map);
+            retainDiverseInternal(merged, map);
             return new Neighbors(nodeId, merged);
         }
 
@@ -254,10 +246,10 @@ private Neighbors insertDiverse(NodeArray toMerge, ConcurrentNeighborMap map) {
             NodeArray merged;
             if (size() > 0) {
                 merged = NodeArray.merge(this, toMerge);
-                retainDiverseInternal(merged, 0, map);
+                retainDiverseInternal(merged, map);
             } else {
                 merged = toMerge.copy(); // still need to copy in case we lose the race
-                retainDiverseInternal(merged, 0, map);
+                retainDiverseInternal(merged, map);
             }
             // insertDiverse usually gets called with a LOT of candidates, so trim down the resulting NodeArray
             var nextNodes = merged.getArrayLength() <= map.nodeArrayLength()
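
`NodeArray.merge` combines two score-sorted arrays into one, keeping order and dropping duplicate node ids. A self-contained sketch of that merge over (node, score) pairs, sorted descending by score in this sketch — illustrative only, not jvector's implementation:

    // Merge two lists sorted by descending score, skipping duplicate node ids
    // (the first occurrence wins).
    import java.util.ArrayList;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;

    class MergeSketch {
        record NodeScore(int node, float score) {}

        static List<NodeScore> merge(List<NodeScore> a, List<NodeScore> b) {
            List<NodeScore> out = new ArrayList<>(a.size() + b.size());
            Set<Integer> seen = new HashSet<>();
            int i = 0, j = 0;
            while (i < a.size() || j < b.size()) {
                NodeScore next;
                if (j >= b.size() || (i < a.size() && a.get(i).score() >= b.get(j).score())) {
                    next = a.get(i++);
                } else {
                    next = b.get(j++);
                }
                if (seen.add(next.node())) {
                    out.add(next);
                }
            }
            return out;
        }
    }
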
@@ -275,17 +267,16 @@ private Neighbors insertNotDiverse(int node, float score, ConcurrentNeighborMap
                 // node already existed in the array -- this is rare enough that we don't check up front
                 return this;
             }
-            next.diverseBefore = min(insertedAt, diverseBefore);
             return next;
         }
 
         /**
          * Retain the diverse neighbors, updating `neighbors` in place
          * @return post-diversity short edges fraction
          */
-        private double retainDiverseInternal(NodeArray neighbors, int diverseBefore, ConcurrentNeighborMap map) {
+        private double retainDiverseInternal(NodeArray neighbors, ConcurrentNeighborMap map) {
             BitSet selected = new FixedBitSet(neighbors.size());
-            double shortEdges = map.diversityProvider.retainDiverse(neighbors, map.maxDegree, diverseBefore, selected);
+            double shortEdges = map.diversityProvider.retainDiverse(neighbors, map.maxDegree, selected);
             neighbors.retain(selected);
             return shortEdges;
         }
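
The retain step is a select-then-compact pattern: the DiversityProvider marks survivors in a BitSet, then `neighbors.retain(selected)` compacts the array in place. A standalone sketch of the compaction half over parallel arrays, using java.util.BitSet (illustrative, not NodeArray's actual code):

    // Compact parallel arrays in place, keeping only the bitset-selected slots.
    class RetainSketch {
        static int retain(int[] nodes, float[] scores, int size, java.util.BitSet selected) {
            int write = 0;
            for (int read = 0; read < size; read++) {
                if (selected.get(read)) {
                    nodes[write] = nodes[read];
                    scores[write] = scores[read];
                    write++;
                }
            }
            return write; // the new logical size
        }
    }
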
@@ -302,7 +293,7 @@ private Neighbors insert(int neighborId, float score, float overflow, Concurrent
             assert hardMax <= map.maxOverflowDegree
                     : String.format("overflow %s could exceed max overflow degree %d", overflow, map.maxOverflowDegree);
 
-            var insertionPoint = insertionPoint(neighborId, score);
+            var insertionPoint = insertionPoint(neighborId);
             if (insertionPoint == -1) {
                 // "new" node already existed
                 return null;
@@ -313,31 +304,16 @@
 
             // batch up the enforcement of the max connection limit, since otherwise
             // we do a lot of duplicate work scanning nodes that we won't remove
-            next.diverseBefore = min(insertionPoint, diverseBefore);
             if (next.size() > hardMax) {
-                retainDiverseInternal(next, next.diverseBefore, map);
-                next.diverseBefore = next.size();
+                retainDiverseInternal(next, map);
             }
 
             return next;
         }
 
         public static long ramBytesUsed(int count) {
             return NodeArray.ramBytesUsed(count) // includes our object header
-                    + Integer.BYTES // nodeId
-                    + Integer.BYTES; // diverseBefore
-        }
-
-        /** Only for testing; this is a linear search */
-        @VisibleForTesting
-        boolean contains(int i) {
-            var it = this.iterator();
-            while (it.hasNext()) {
-                if (it.nextInt() == i) {
-                    return true;
-                }
-            }
-            return false;
+                    + Integer.BYTES; // nodeId
         }
     }
 
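One consequence of dropping the field: the memory estimate loses one Integer.BYTES term. The actual JVM footprint may not shrink, since object sizes are rounded up to 8-byte alignment; the estimator just tracks declared fields. A tiny illustrative sketch (the nodeArrayBytes argument stands in for NodeArray.ramBytesUsed, which includes the object header):

    class RamEstimateSketch {
        static long neighborsBytes(long nodeArrayBytes) {
            return nodeArrayBytes    // array storage + object header
                    + Integer.BYTES; // nodeId (previously one more Integer.BYTES for diverseBefore)
        }
    }
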
io/github/jbellis/jvector/graph/GraphIndexBuilder.java
@@ -61,8 +61,7 @@ public class GraphIndexBuilder implements Closeable {
     private static final Logger logger = LoggerFactory.getLogger(GraphIndexBuilder.class);
 
     private final int beamWidth;
-    private final ExplicitThreadLocal<NodeArray> naturalScratch;
-    private final ExplicitThreadLocal<NodeArray> concurrentScratch;
+    private final ExplicitThreadLocal<NodeArray> scratch;
 
     private final int dimension;
     private final float neighborOverflow;
@@ -333,8 +332,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
         });
 
         // in scratch we store candidates in reverse order: worse candidates are first
-        this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));
-        this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));
+        this.scratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(2 * max(beamWidth, graph.maxDegree() + 1)));
 
         this.rng = new Random(0);
     }
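
The two per-thread buffers are folded into one with double the initial capacity — presumably because the single NodeArray must now hold both the natural (search-result) candidates and the concurrent (in-progress-insertion) candidates at the same time before pruning. A sketch of that sizing reasoning; treat the doubling as a heuristic for the common case rather than a hard bound:

    class ScratchSizingSketch {
        // One pooled buffer must fit both candidate sets at once.
        static int scratchCapacity(int beamWidth, int maxDegree) {
            int perSet = Math.max(beamWidth, maxDegree + 1); // either set's typical worst case
            return 2 * perSet;                               // natural + concurrent together
        }
    }
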
@@ -375,10 +373,10 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi
                 var newNeighbors = new NodeArray(oldNeighborsIt.size());
                 while (oldNeighborsIt.hasNext()) {
                     int neighbor = oldNeighborsIt.nextInt();
-                    // since we're using a different score provider, use insertSorted instead of addInOrder
-                    newNeighbors.insertSorted(neighbor, sf.similarityTo(neighbor));
+                    newNeighbors.addInOrder(neighbor, sf.similarityTo(neighbor));
                 }
                 newBuilder.graph.addNode(lvl, i, newNeighbors);
+                newBuilder.graph.markComplete(new NodeAtLevel(lvl, i));
             }
         });
     }).join();
@@ -555,8 +553,7 @@ public long addGraphNode(int node, SearchScoreProvider searchScoreProvider) {
         var inProgressBefore = insertionsInProgress.clone();
         try (var gs = searchers.get()) {
             gs.setView(graph.getView()); // new snapshot
-            var naturalScratchPooled = naturalScratch.get();
-            var concurrentScratchPooled = concurrentScratch.get();
+            var scratchPooled = scratch.get();
 
             var bits = new ExcludingBits(nodeLevel.node);
             var entry = graph.entry();
@@ -579,7 +576,7 @@ public long addGraphNode(int node, SearchScoreProvider searchScoreProvider) {
                         neighbors[index.getAndIncrement()] = new NodeScore(neighbor, score);
                     });
                     Arrays.sort(neighbors);
-                    updateNeighborsOneLayer(lvl, nodeLevel.node, neighbors, naturalScratchPooled, inProgressBefore, concurrentScratchPooled, searchScoreProvider);
+                    updateNeighborsOneLayer(lvl, nodeLevel.node, neighbors, scratchPooled, inProgressBefore, searchScoreProvider);
                 }
                 gs.setEntryPointsFromPreviousLayer();
             }
@@ -588,7 +585,7 @@ public long addGraphNode(int node, SearchScoreProvider searchScoreProvider) {
                 result = gs.resume(beamWidth, beamWidth, 0.0f, 0.0f);
             }
 
-            updateNeighborsOneLayer(0, nodeLevel.node, result.getNodes(), naturalScratchPooled, inProgressBefore, concurrentScratchPooled, searchScoreProvider);
+            updateNeighborsOneLayer(0, nodeLevel.node, result.getNodes(), scratchPooled, inProgressBefore, searchScoreProvider);
 
             graph.markComplete(nodeLevel);
         } catch (Exception e) {
@@ -600,16 +597,20 @@
         return IntStream.rangeClosed(0, nodeLevel.level).mapToLong(graph::ramBytesUsedOneNode).sum();
     }
 
-    private void updateNeighborsOneLayer(int layer, int node, NodeScore[] neighbors, NodeArray naturalScratchPooled, ConcurrentSkipListSet<NodeAtLevel> inProgressBefore, NodeArray concurrentScratchPooled, SearchScoreProvider ssp) {
+    private void updateNeighborsOneLayer(int layer, int node, NodeScore[] neighbors, NodeArray scratchPooled, ConcurrentSkipListSet<NodeAtLevel> inProgressBefore, SearchScoreProvider ssp) {
         // Update neighbors with these candidates.
         // The DiskANN paper calls for using the entire set of visited nodes along the search path as
         // potential candidates, but in practice we observe neighbor lists being completely filled using
         // just the topK results. (Since the Robust Prune algorithm prioritizes closer neighbors,
         // this means that considering additional nodes from the search path, that are by definition
         // farther away than the ones in the topK, would not change the result.)
-        var natural = toScratchCandidates(neighbors, naturalScratchPooled);
-        var concurrent = getConcurrentCandidates(layer, node, inProgressBefore, concurrentScratchPooled, ssp.scoreFunction());
-        updateNeighbors(layer, node, natural, concurrent);
+        scratchPooled.clear();
+        toScratchCandidates(neighbors, scratchPooled);
+        getConcurrentCandidates(layer, node, inProgressBefore, scratchPooled, ssp.scoreFunction());
+
+        // toMerge may be approximate-scored, but insertDiverse will compute exact scores for the diverse ones
+        var graphNeighbors = graph.layers.get(layer).insertDiverse(node, scratchPooled);
+        graph.layers.get(layer).backlink(graphNeighbors, node, neighborOverflow);
     }
 
     @VisibleForTesting
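
The old pipeline filled two buffers and merged them inside updateNeighbors; the new one clears a single buffer, lets both producers insert into it in sorted order, and prunes the union once. A condensed, self-contained model of the new shape — names are illustrative, the sort direction here is descending score, and jvector's NodeArray additionally deduplicates and computes exact scores during insertDiverse:

    import java.util.ArrayList;
    import java.util.List;

    class SingleScratchSketch {
        record Candidate(int node, float score) {}

        // Keep a consistent score order as each producer inserts.
        static void insertSorted(List<Candidate> scratch, Candidate c) {
            int i = 0;
            while (i < scratch.size() && scratch.get(i).score() >= c.score()) i++;
            scratch.add(i, c);
        }

        static List<Candidate> gather(List<Candidate> natural, List<Candidate> concurrent) {
            List<Candidate> scratch = new ArrayList<>(); // one buffer, cleared per use
            natural.forEach(c -> insertSorted(scratch, c));
            concurrent.forEach(c -> insertSorted(scratch, c));
            return scratch; // diversity pruning then runs once over the union
        }
    }
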
@@ -740,25 +741,9 @@ public synchronized long removeDeletedNodes() {
         return memorySize;
     }
 
-    private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) {
-        // if either natural or concurrent is empty, skip the merge
-        NodeArray toMerge;
-        if (concurrent.size() == 0) {
-            toMerge = natural;
-        } else if (natural.size() == 0) {
-            toMerge = concurrent;
-        } else {
-            toMerge = NodeArray.merge(natural, concurrent);
-        }
-        // toMerge may be approximate-scored, but insertDiverse will compute exact scores for the diverse ones
-        var neighbors = graph.layers.get(layer).insertDiverse(nodeId, toMerge);
-        graph.layers.get(layer).backlink(neighbors, nodeId, neighborOverflow);
-    }
-
     private static NodeArray toScratchCandidates(NodeScore[] candidates, NodeArray scratch) {
-        scratch.clear();
         for (var candidate : candidates) {
-            scratch.addInOrder(candidate.node, candidate.score);
+            scratch.insertSorted(candidate.node, candidate.score);
         }
         return scratch;
     }
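
With a single shared scratch, `addInOrder` (append, trusting that the caller presents candidates already in order) is no longer valid once a second producer adds out-of-order entries, so both helpers now use `insertSorted`, and the single `clear()` moved to the caller. A standalone sketch of sorted insertion over parallel arrays, descending by score in this sketch — illustrative; jvector's NodeArray additionally rejects duplicate node ids:

    class InsertSortedSketch {
        int[] nodes = new int[16];
        float[] scores = new float[16];
        int size = 0;

        void insertSorted(int node, float score) {
            if (size == nodes.length) { // grow when full
                nodes = java.util.Arrays.copyOf(nodes, size * 2);
                scores = java.util.Arrays.copyOf(scores, size * 2);
            }
            int lo = 0, hi = size;
            while (lo < hi) {           // binary-search the first lower-scored slot
                int mid = (lo + hi) >>> 1;
                if (scores[mid] >= score) lo = mid + 1; else hi = mid;
            }
            System.arraycopy(nodes, lo, nodes, lo + 1, size - lo);
            System.arraycopy(scores, lo, scores, lo + 1, size - lo);
            nodes[lo] = node;
            scores[lo] = score;
            size++;
        }
    }
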
@@ -769,7 +754,6 @@ private NodeArray getConcurrentCandidates(int layer,
                                               NodeArray scratch,
                                               ScoreFunction scoreFunction)
     {
-        scratch.clear();
         for (NodeAtLevel n : inProgress) {
             if (n.node == newNode || n.level < layer) {
                 continue;