Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ public class LuminaVectorGlobalIndexWriter implements GlobalIndexSingletonWriter

private final GlobalIndexFileWriter fileWriter;
private final LuminaVectorIndexOptions options;
private final Map<String, String> luminaOptions;
private final int dim;

/** Temporary file storing records as [rowId (long)][float * dim] in native byte order. */
Expand All @@ -99,14 +100,13 @@ public LuminaVectorGlobalIndexWriter(
LuminaVectorIndexOptions options) {
this.fileWriter = fileWriter;
this.options = options;
this.dim = options.dimension();
this.dim = validateAndResolveDimension(fieldType, options);
this.luminaOptions = options.toLuminaOptions(dim);
this.count = 0;
this.closed = false;
this.recordSizeInBytes = checkedRecordSize(dim, IO_BUFFER_SIZE);
this.vectorBuf = new float[dim];

validateFieldType(fieldType);

try {
this.tempVectorFile = File.createTempFile("lumina-vectors-", ".bin");
this.tempVectorFile.deleteOnExit();
Expand All @@ -120,22 +120,33 @@ public LuminaVectorGlobalIndexWriter(
}
}

private void validateFieldType(DataType dataType) {
private static int validateAndResolveDimension(
DataType dataType, LuminaVectorIndexOptions options) {
if (dataType instanceof VectorType) {
DataType elementType = ((VectorType) dataType).getElementType();
VectorType vectorType = (VectorType) dataType;
DataType elementType = vectorType.getElementType();
if (!(elementType instanceof FloatType)) {
throw new IllegalArgumentException(
"Lumina vector index requires float vector, but got: " + elementType);
}
return;
int typeDimension = vectorType.getLength();
if (options.isDimensionExplicitlyConfigured() && options.dimension() != typeDimension) {
throw new IllegalArgumentException(
String.format(
"%s configured %d conflicts with VECTOR length %d.",
LuminaVectorIndexOptions.DIMENSION.key(),
options.dimension(),
typeDimension));
}
return typeDimension;
}
if (dataType instanceof ArrayType) {
DataType elementType = ((ArrayType) dataType).getElementType();
if (!(elementType instanceof FloatType)) {
throw new IllegalArgumentException(
"Lumina vector index requires float array, but got: " + elementType);
}
return;
return options.dimension();
}
throw new IllegalArgumentException(
"Lumina vector index requires VectorType or ArrayType<FLOAT>, but got: "
Expand Down Expand Up @@ -255,7 +266,7 @@ private ResultEntry buildIndex() throws IOException {

try (LuminaIndex index =
LuminaIndex.createForBuild(
options.indexType(), dim, options.metric(), options.toLuminaOptions())) {
options.indexType(), dim, options.metric(), luminaOptions)) {

// Pretrain and insert via streaming file-backed Dataset API
long phaseStart = System.currentTimeMillis();
Expand Down Expand Up @@ -288,7 +299,7 @@ private ResultEntry buildIndex() throws IOException {
"Lumina index build completed in {} ms",
System.currentTimeMillis() - buildStart);

LuminaIndexMeta meta = new LuminaIndexMeta(options.toLuminaOptions());
LuminaIndexMeta meta = new LuminaIndexMeta(luminaOptions);
// rowCount = logical rows including nulls (not just indexed vectors)
return new ResultEntry(fileName, logicalRowId, meta.serialize());
}
Expand All @@ -300,7 +311,7 @@ private ResultEntry buildIndex() throws IOException {
* thread pool is sized correctly before the builder is created.
*/
private void configureExecutorThreadCount() {
Map<String, String> luminaOpts = options.toLuminaOptions();
Map<String, String> luminaOpts = luminaOptions;
String threadCountKey =
LuminaVectorIndexOptions.toLuminaKey(
LuminaVectorIndexOptions.DISKANN_BUILD_THREAD_COUNT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,18 @@ public class LuminaVectorIndexOptions {
.withDescription("The parallel number for search.");

private final int dimension;
private final boolean dimensionExplicitlyConfigured;
private final LuminaVectorMetric metric;
private final String indexType;
private final Map<String, String> luminaOptions;

public LuminaVectorIndexOptions(Options options) {
this.dimension = validatePositive(options.get(DIMENSION), DIMENSION.key());
this.dimensionExplicitlyConfigured = options.contains(DIMENSION);
this.metric = parseMetric(options.get(DISTANCE_METRIC));
this.indexType = options.get(INDEX_TYPE);
validateEncodingMetricCombination(options.get(ENCODING_TYPE), this.metric);
this.luminaOptions = buildLuminaOptions(options, this.dimension);
this.luminaOptions = buildLuminaOptions(options);
}

/**
Expand All @@ -128,13 +130,26 @@ public LuminaVectorIndexOptions(Options options) {
* diskann.build.ef_construction}.
*/
public Map<String, String> toLuminaOptions() {
return new LinkedHashMap<>(luminaOptions);
return toLuminaOptions(dimension);
}

/**
 * Builds a copy of the Lumina options whose dimension entry is set to the given value.
 *
 * <p>The given dimension is passed through {@code validatePositive}, stored under the Lumina
 * key of {@code DIMENSION}, and then handed to {@code capPqM} together with the copy.
 *
 * @param dimension the vector dimension to record in the returned options
 * @return a new {@link LinkedHashMap} seeded from the options held by this instance
 */
public Map<String, String> toLuminaOptions(int dimension) {
    String dimensionValue = String.valueOf(validatePositive(dimension, DIMENSION.key()));
    Map<String, String> copy = new LinkedHashMap<>(luminaOptions);
    copy.put(toLuminaKey(DIMENSION), dimensionValue);
    capPqM(copy, dimension);
    return copy;
}

/**
 * Returns the dimension read from the {@code DIMENSION} option when this instance was built.
 *
 * @return the configured vector dimension
 */
public int dimension() {
    return this.dimension;
}

/**
 * Tells whether the options passed to the constructor contained the {@code DIMENSION} option.
 *
 * @return {@code true} if {@code DIMENSION} was present in the constructor options
 */
public boolean isDimensionExplicitlyConfigured() {
    return this.dimensionExplicitlyConfigured;
}

/**
 * Returns the metric parsed from the {@code DISTANCE_METRIC} option at construction time.
 *
 * @return the distance metric of this index configuration
 */
public LuminaVectorMetric metric() {
    return this.metric;
}
Expand Down Expand Up @@ -177,7 +192,7 @@ public static String toLuminaKey(ConfigOption<?> option) {
* like {@code index.type} are always present in the metadata, matching paimon-cpp behavior.
*/
@SuppressWarnings("unchecked")
private static Map<String, String> buildLuminaOptions(Options options, int dimension) {
private static Map<String, String> buildLuminaOptions(Options options) {
Map<String, String> result = new LinkedHashMap<>();
// Populate all known options with their resolved values (user-set or default).
for (ConfigOption<?> opt : ALL_OPTIONS) {
Expand All @@ -193,8 +208,6 @@ private static Map<String, String> buildLuminaOptions(Options options, int dimen
result.putIfAbsent(key.substring(LUMINA_PREFIX.length()), entry.getValue());
}
}
// PQ encoding requires pq.m <= dimension; auto-cap to avoid native init failures.
capPqM(result, dimension);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.lumina.index;

import org.apache.paimon.fs.PositionOutputStream;
import org.apache.paimon.globalindex.io.GlobalIndexFileWriter;
import org.apache.paimon.options.Options;
import org.apache.paimon.types.ArrayType;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.FloatType;
import org.apache.paimon.types.VectorType;

import org.junit.jupiter.api.Test;

import java.lang.reflect.Field;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
 * Unit tests for {@link LuminaVectorGlobalIndexWriter}.
 *
 * <p>Each case builds a writer over a stub {@link GlobalIndexFileWriter} and checks which vector
 * dimension the writer ends up using for {@link VectorType} and {@link ArrayType} fields.
 */
public class LuminaVectorGlobalIndexWriterTest {

    /** A VECTOR field without a configured dimension takes its dimension from the type length. */
    @Test
    public void testVectorTypeUsesTypeDimensionByDefault() {
        Options conf = new Options();
        conf.setString(LuminaVectorIndexOptions.DISTANCE_METRIC.key(), "l2");
        LuminaVectorIndexOptions indexOptions = new LuminaVectorIndexOptions(conf);
        DataType twoDimVector = new VectorType(2, new FloatType());

        try (LuminaVectorGlobalIndexWriter writer = openWriter(twoDimVector, indexOptions)) {
            Map<String, String> effectiveOptions = readLuminaOptions(writer);
            String dimensionKey =
                    LuminaVectorIndexOptions.toLuminaKey(LuminaVectorIndexOptions.DIMENSION);
            assertThat(effectiveOptions).containsEntry(dimensionKey, "2");

            // A vector of the type's length is accepted ...
            writer.write(new float[] {1.0f, 0.0f});

            // ... while a longer one is rejected with both lengths in the message.
            float[] tooLong = new float[] {1.0f, 0.0f, 0.0f};
            assertThatThrownBy(() -> writer.write(tooLong))
                    .isInstanceOf(IllegalArgumentException.class)
                    .hasMessageContaining("expected 2")
                    .hasMessageContaining("got 3");
        }
    }

    /** pq.m is kept when it is below the VECTOR length and capped to that length when above. */
    @Test
    public void testVectorTypePqMUsesTypeDimensionForCap() {
        assertPqMForVectorLength(256, 192, 192);
        assertPqMForVectorLength(256, 300, 256);
    }

    /** An explicitly configured dimension that differs from the VECTOR length is rejected. */
    @Test
    public void testVectorTypeRejectsExplicitDimensionConflict() {
        LuminaVectorIndexOptions indexOptions =
                new LuminaVectorIndexOptions(l2OptionsWithDimension(3));
        DataType twoDimVector = new VectorType(2, new FloatType());

        assertThatThrownBy(() -> openWriter(twoDimVector, indexOptions))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessageContaining(LuminaVectorIndexOptions.DIMENSION.key())
                .hasMessageContaining("configured 3")
                .hasMessageContaining("VECTOR length 2");
    }

    /** An ARRAY field has no length of its own, so the configured dimension is enforced. */
    @Test
    public void testArrayTypeUsesOptionDimension() {
        LuminaVectorIndexOptions indexOptions =
                new LuminaVectorIndexOptions(l2OptionsWithDimension(3));
        DataType floatArray = new ArrayType(new FloatType());

        try (LuminaVectorGlobalIndexWriter writer = openWriter(floatArray, indexOptions)) {
            writer.write(new float[] {1.0f, 0.0f, 0.0f});

            float[] tooShort = new float[] {1.0f, 0.0f};
            assertThatThrownBy(() -> writer.write(tooShort))
                    .isInstanceOf(IllegalArgumentException.class)
                    .hasMessageContaining("expected 3")
                    .hasMessageContaining("got 2");
        }
    }

    /** Creates options with an explicit dimension and the "l2" distance metric. */
    private static Options l2OptionsWithDimension(int dimension) {
        Options conf = new Options();
        conf.setInteger(LuminaVectorIndexOptions.DIMENSION.key(), dimension);
        conf.setString(LuminaVectorIndexOptions.DISTANCE_METRIC.key(), "l2");
        return conf;
    }

    /** Creates a writer for the given field type on top of the stub file writer. */
    private static LuminaVectorGlobalIndexWriter openWriter(
            DataType fieldType, LuminaVectorIndexOptions indexOptions) {
        return new LuminaVectorGlobalIndexWriter(stubFileWriter(), fieldType, indexOptions);
    }

    /** File writer stub: echoes the prefix as file name and refuses to open output streams. */
    private static GlobalIndexFileWriter stubFileWriter() {
        return new GlobalIndexFileWriter() {
            @Override
            public String newFileName(String prefix) {
                return prefix;
            }

            @Override
            public PositionOutputStream newOutputStream(String fileName) {
                throw new UnsupportedOperationException();
            }
        };
    }

    /**
     * Builds a writer for a float VECTOR of {@code vectorDimension} with pq.m set to {@code
     * configuredPqM}, then asserts the dimension and pq.m entries of the writer's Lumina options.
     */
    private static void assertPqMForVectorLength(
            int vectorDimension, int configuredPqM, int expectedPqM) {
        Options conf = new Options();
        conf.setString(LuminaVectorIndexOptions.DISTANCE_METRIC.key(), "l2");
        conf.setInteger(LuminaVectorIndexOptions.ENCODING_PQ_M.key(), configuredPqM);
        LuminaVectorIndexOptions indexOptions = new LuminaVectorIndexOptions(conf);
        DataType vectorField = new VectorType(vectorDimension, new FloatType());

        try (LuminaVectorGlobalIndexWriter writer = openWriter(vectorField, indexOptions)) {
            Map<String, String> effectiveOptions = readLuminaOptions(writer);
            String dimensionKey =
                    LuminaVectorIndexOptions.toLuminaKey(LuminaVectorIndexOptions.DIMENSION);
            String pqMKey =
                    LuminaVectorIndexOptions.toLuminaKey(LuminaVectorIndexOptions.ENCODING_PQ_M);
            assertThat(effectiveOptions)
                    .containsEntry(dimensionKey, String.valueOf(vectorDimension))
                    .containsEntry(pqMKey, String.valueOf(expectedPqM));
        }
    }

    /** Reads the writer's private {@code luminaOptions} field through reflection. */
    private static Map<String, String> readLuminaOptions(LuminaVectorGlobalIndexWriter writer) {
        try {
            Field optionsField =
                    LuminaVectorGlobalIndexWriter.class.getDeclaredField("luminaOptions");
            optionsField.setAccessible(true);
            // The field is declared as Map<String, String> in the writer, so the cast is safe.
            @SuppressWarnings("unchecked")
            Map<String, String> captured = (Map<String, String>) optionsField.get(writer);
            return captured;
        } catch (ReflectiveOperationException e) {
            throw new AssertionError(e);
        }
    }
}