Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import io.github.jbellis.jvector.annotations.Experimental;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors;
import io.github.jbellis.jvector.graph.disk.AbstractGraphIndexWriter;
import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
import io.github.jbellis.jvector.graph.diversity.DiversityProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.util.Accountable;
Expand All @@ -49,7 +51,6 @@
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.StampedLock;
import java.util.function.Function;
import java.util.stream.IntStream;

/**
Expand All @@ -62,6 +63,7 @@
public class OnHeapGraphIndex implements MutableGraphIndex {
// Used for saving and loading OnHeapGraphIndex
public static final int MAGIC = 0x75EC4012; // JVECTOR, with some imagination
public static final int SERIALIZE_VERSION = 4;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is the JVector version, consider writing out the full version string (e.g. 4.0.0-rc4)?


// The current entry node for searches
private final AtomicReference<NodeAtLevel> entryPoint;
Expand Down Expand Up @@ -523,17 +525,43 @@ public String toString() {
}

/**
 * Serializes this graph to the given {@link DataOutput} so it can be reloaded into memory later.
 *
 * <p>Existing ordinals are written unchanged, even when they are not compact
 * (which can happen if some nodes were deleted). This is done by delegating to
 * {@code save(DataOutput, OrdinalMapper)} with an identity mapper constructed
 * from {@code getIdUpperBound() - 1}.
 *
 * @param out the destination the serialized graph is written to
 * @throws IOException if writing to {@code out} fails
 */
@Experimental
@Deprecated
public void save(DataOutput out) throws IOException {
    var identityMapper = new OrdinalMapper.IdentityMapper(getIdUpperBound() - 1);
    save(out, identityMapper);
}

/**
 * Serializes this graph to the given {@link DataOutput} using a compact ordinal space.
 *
 * <p>Ordinal holes (left behind by deletions) are not present in the saved graph:
 * nodes are renumbered through a compact mapping that preserves the relative
 * order of the existing nodes.
 *
 * @param out the destination the serialized graph is written to
 * @throws IOException if writing to {@code out} fails
 */
@Experimental
@Deprecated
public void saveWithCompactOrdinals(DataOutput out) throws IOException {
    // Compute the old -> new renumbering first, then wrap it in a mapper for the shared save path.
    var compactRenumbering = AbstractGraphIndexWriter.sequentialRenumbering(this);
    var compactMapper = new OrdinalMapper.MapMapper(compactRenumbering);
    save(out, compactMapper);
}

/**
* Saves the graph to the given DataOutput for reloading into memory later
*/
@Experimental
@Deprecated
public void save(DataOutput out, OrdinalMapper mapper) throws IOException {
if (!allMutationsCompleted()) {
throw new IllegalStateException("Cannot save a graph with pending mutations. Call cleanup() first");
}

out.writeInt(OnHeapGraphIndex.MAGIC); // the magic number
out.writeInt(4); // The version
out.writeInt(SERIALIZE_VERSION); // The version

// Write graph-level properties.
out.writeInt(layers.size());
Expand All @@ -543,7 +571,7 @@ public void save(DataOutput out) throws IOException {

var entryNode = entryPoint.get();
assert entryNode.level == getMaxLevel();
out.writeInt(entryNode.node);
out.writeInt(mapper.oldToNew(entryNode.node));

for (int level = 0; level < layers.size(); level++) {
out.writeInt(size(level));
Expand All @@ -553,19 +581,19 @@ public void save(DataOutput out) throws IOException {
while (it.hasNext()) {
int nodeId = it.nextInt();
var neighbors = layers.get(level).get(nodeId);
out.writeInt(nodeId);
out.writeInt(mapper.oldToNew(nodeId));
out.writeInt(neighbors.size());

for (int n = 0; n < neighbors.size(); n++) {
out.writeInt(neighbors.getNode(n));
out.writeInt(mapper.oldToNew(neighbors.getNode(n)));
out.writeFloat(neighbors.getScore(n));
}
}
}
}

/**
* Saves the graph to the given DataOutput for reloading into memory later
* Loads the graph from the given RandomAccessReader
*/
@Experimental
@Deprecated
Expand All @@ -576,8 +604,8 @@ public static OnHeapGraphIndex load(RandomAccessReader in, int dimension, double
}

int version = in.readInt(); // The version
if (version != 4) {
throw new IOException("Unsupported version: " + version);
if (version != SERIALIZE_VERSION) {
throw new IOException("Unsupported version: " + version + ", expected " + SERIALIZE_VERSION);
}

// Read graph-level properties.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.github.jbellis.jvector.disk.SimpleMappedReader;
import io.github.jbellis.jvector.disk.SimpleWriter;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
Expand All @@ -42,11 +43,12 @@
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.apache.commons.lang3.ArrayUtils.shuffle;
import static org.junit.Assert.assertEquals;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
Expand Down Expand Up @@ -79,8 +81,8 @@ public class OnHeapGraphIndexTest extends RandomizedTest {
private ArrayList<int[]> groundTruthAllVectors;
private BuildScoreProvider baseBuildScoreProvider;
private BuildScoreProvider allBuildScoreProvider;
private ImmutableGraphIndex baseGraphIndex;
private ImmutableGraphIndex allGraphIndex;
private OnHeapGraphIndex baseGraphIndex;
private OnHeapGraphIndex allGraphIndex;

@Before
public void setup() throws IOException {
Expand Down Expand Up @@ -118,23 +120,30 @@ public void setup() throws IOException {
// score provider using the raw, in-memory vectors
baseBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, SIMILARITY_FUNCTION);
allBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(allVectorsRavv, SIMILARITY_FUNCTION);
var baseGraphIndexBuilder = new GraphIndexBuilder(baseBuildScoreProvider,
try (
var baseGraphIndexBuilder = new GraphIndexBuilder(baseBuildScoreProvider,
baseVectorsRavv.dimension(),
M, // graph degree
BEAM_WIDTH, // construction search depth
NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor
ALPHA, // relax neighbor diversity requirement by this factor
ADD_HIERARCHY); // add the hierarchy
var allGraphIndexBuilder = new GraphIndexBuilder(allBuildScoreProvider,
) {
baseGraphIndex = (OnHeapGraphIndex) baseGraphIndexBuilder.build(baseVectorsRavv);
}

try (
var allGraphIndexBuilder = new GraphIndexBuilder(allBuildScoreProvider,
allVectorsRavv.dimension(),
M, // graph degree
BEAM_WIDTH, // construction search depth
NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor
ALPHA, // relax neighbor diversity requirement by this factor
ADD_HIERARCHY); // add the hierarchy
) {
allGraphIndex = (OnHeapGraphIndex) allGraphIndexBuilder.build(allVectorsRavv);
}

baseGraphIndex = baseGraphIndexBuilder.build(baseVectorsRavv);
allGraphIndex = allGraphIndexBuilder.build(allVectorsRavv);
}

@After
Expand All @@ -148,7 +157,7 @@ public void tearDown() {
* @throws IOException exception
*/
@Test
public void testGraphConstructionWithNonIdentityOrdinalMapping() throws IOException {
public void testGraphConstructionWithRemappedRavv() throws IOException {
// create reversed mapping from graph node id to ravv ordinal
int[] graphToRavvOrdMap = IntStream.range(0, baseVectorsRavv.size()).map(i -> baseVectorsRavv.size() - 1 - i).toArray();
final RemappedRandomAccessVectorValues remappedBaseVectorsRavv = new RemappedRandomAccessVectorValues(baseVectorsRavv, graphToRavvOrdMap);
Expand Down Expand Up @@ -212,35 +221,33 @@ public void testReconstructionOfOnHeapGraphIndex_withIdentityOrdinalMapping() th
*/
@Test
public void testReconstructionOfOnHeapGraphIndex_withNonIdentityOrdinalMapping() throws IOException {
var graphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName());
var heapGraphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap");

// create reversed mapping from graph node id to ravv ordinal
int[] graphToRavvOrdMap = IntStream.range(0, baseVectorsRavv.size()).map(i -> baseVectorsRavv.size() - 1 - i).toArray();
final RemappedRandomAccessVectorValues remmappedRavv = new RemappedRandomAccessVectorValues(baseVectorsRavv, graphToRavvOrdMap);
var bsp = BuildScoreProvider.randomAccessScoreProvider(remmappedRavv, SIMILARITY_FUNCTION);
try (var baseGraphIndexBuilder = new GraphIndexBuilder(bsp,
baseVectorsRavv.dimension(),
M, // graph degree
BEAM_WIDTH, // construction search depth
NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor
ALPHA, // relax neighbor diversity requirement by this factor
ADD_HIERARCHY); // add the hierarchy) {
var baseGraphIndex = baseGraphIndexBuilder.build(remmappedRavv)) {
log.info("Writing graph to {}", graphOutputPath);
TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, graphOutputPath);
// shuffle the ordinals arbitrarily, ensuring the output ordinal space is sparse (with high probability)
int[] oldToNewOrds = getRandom().ints(0, (int) (1.5 * baseVectorsRavv.size())).distinct().limit(baseVectorsRavv.size()).toArray();

log.info("Writing on-heap graph to {}", heapGraphOutputPath);
try (SimpleWriter writer = new SimpleWriter(heapGraphOutputPath.toAbsolutePath())) {
((OnHeapGraphIndex) baseGraphIndex).save(writer);
}
Map<Integer, Integer> oldToNewOrdMap = IntStream.range(0, oldToNewOrds.length).boxed().collect(Collectors.toMap(i -> i, i -> oldToNewOrds[i]));
OrdinalMapper ordinalMapper = new OrdinalMapper.MapMapper(oldToNewOrdMap);

log.info("Reading on-heap graph from {}", heapGraphOutputPath);
try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) {
MutableGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex.load(readerSupplier.get(), baseVectorsRavv.dimension(), NEIGHBOR_OVERFLOW, new VamanaDiversityProvider(bsp, ALPHA));
TestUtil.assertGraphEquals(baseGraphIndex, reconstructedOnHeapGraphIndex);
}
var shuffledRavv = new OrdinalMappedRavv(baseVectorsRavv, ordinalMapper);
var shuffledBsp = BuildScoreProvider.randomAccessScoreProvider(shuffledRavv, SIMILARITY_FUNCTION);
var shuffledGt = transformGroundTruth(groundTruthBaseVectors, ordinalMapper);

log.info("Writing on-heap graph to {}, with ordinal shuffling", heapGraphOutputPath);
try (SimpleWriter writer = new SimpleWriter(heapGraphOutputPath.toAbsolutePath())) {
((OnHeapGraphIndex) baseGraphIndex).save(writer, new OrdinalMapper.MapMapper(oldToNewOrdMap));
}

log.info("Reading on-heap graph from {}", heapGraphOutputPath);
OnHeapGraphIndex deserializedGraph;
try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) {
deserializedGraph = OnHeapGraphIndex.load(readerSupplier.get(), baseVectorsRavv.dimension(), NEIGHBOR_OVERFLOW, new VamanaDiversityProvider(shuffledBsp, ALPHA));
}

var baseRecall = calculateAverageRecall(baseGraphIndex, baseBuildScoreProvider, queryVectors, groundTruthBaseVectors, TOP_K, null);
var deserializedRecall = calculateAverageRecall(deserializedGraph, shuffledBsp, queryVectors, shuffledGt, TOP_K, null);

assertEquals(baseRecall, deserializedRecall, 1e-6);
}

/**
Expand Down Expand Up @@ -336,7 +343,7 @@ public static void validateVectors(OnDiskGraphIndex.View view, RandomAccessVecto
/**
 * Creates a vector of the requested dimension whose components are random floats.
 *
 * <p>Each component is assigned exactly once from {@code randomFloat()}
 * (presumably the generator supplied by the randomized-testing base class — confirm).
 * The earlier extra assignment from {@code Math.random()} was a dead store that was
 * overwritten immediately and drew from a different random source.
 *
 * @param dimension number of components in the vector
 * @return a newly created vector with {@code dimension} random components
 */
private VectorFloat<?> createRandomVector(int dimension) {
    VectorFloat<?> vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension);
    for (int i = 0; i < dimension; i++) {
        vector.set(i, randomFloat());
    }
    return vector;
}
Expand All @@ -360,6 +367,10 @@ private static int[] getGroundTruth(RandomAccessVectorValues ravv, VectorFloat<?
return exactResults.stream().limit(topK).mapToInt(nodeScore -> nodeScore.node).toArray();
}

/**
 * Translates every ground-truth ordinal through {@code mapper.oldToNew}.
 *
 * @param groundTruth one array of ordinals per query; the input arrays are not modified
 * @param mapper the mapping applied to each ordinal
 * @return a new list holding, for each input array, a new array of mapped ordinals in the same order
 */
private static List<int[]> transformGroundTruth(List<int[]> groundTruth, OrdinalMapper mapper) {
    List<int[]> remapped = new ArrayList<>(groundTruth.size());
    for (int[] ordinals : groundTruth) {
        int[] mapped = new int[ordinals.length];
        for (int i = 0; i < ordinals.length; i++) {
            mapped[i] = mapper.oldToNew(ordinals[i]);
        }
        remapped.add(mapped);
    }
    return remapped;
}

/**
* Calculate average recall across multiple query vectors for more stable measurements
* @param graphIndex the graph index to search
Expand All @@ -371,7 +382,7 @@ private static int[] getGroundTruth(RandomAccessVectorValues ravv, VectorFloat<?
* @return the average recall across all queries
*/
private static float calculateAverageRecall(ImmutableGraphIndex graphIndex, BuildScoreProvider buildScoreProvider,
ArrayList<VectorFloat<?>> queryVectors, ArrayList<int[]> groundTruths,
List<VectorFloat<?>> queryVectors, List<int[]> groundTruths,
int k, int[] graphToRavvOrdMap) throws IOException {
float totalRecall = 0.0f;
for (int i = 0; i < queryVectors.size(); i++) {
Expand Down Expand Up @@ -424,4 +435,39 @@ private static float calculateRecall(Set<Integer> predicted, int[] groundTruth,

return ((float) hits) / (float) actualK;
}

class OrdinalMappedRavv implements RandomAccessVectorValues {
private final RandomAccessVectorValues ravv;
private final OrdinalMapper ordinalMapper;

public OrdinalMappedRavv(RandomAccessVectorValues ravv, OrdinalMapper ordinalMapper) {
this.ravv = ravv;
this.ordinalMapper = ordinalMapper;
}

@Override
public int size() {
throw new UnsupportedOperationException("This RAVV doesn't know it's size");
}

@Override
public int dimension() {
return ravv.dimension();
}

@Override
public VectorFloat<?> getVector(int nodeId) {
return ravv.getVector(ordinalMapper.newToOld(nodeId));
}

@Override
public boolean isValueShared() {
return ravv.isValueShared();
}

@Override
public RandomAccessVectorValues copy() {
return new OrdinalMappedRavv(ravv.copy(), ordinalMapper);
}
}
}