From 8f87ccdf5062a2c2d77640038dc00f4663d26750 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=A3=9E?= Date: Thu, 14 May 2026 11:37:50 -0400 Subject: [PATCH 1/5] feat(java): add zvec-java API and backends --- java/zvec-java/pom.xml | 73 ++ java/zvec-java/zvec-java-api/pom.xml | 71 ++ .../src/main/java/org/zvec/Collection.java | 125 +++ .../main/java/org/zvec/CollectionSchema.java | 96 ++ .../src/main/java/org/zvec/DataType.java | 25 + .../src/main/java/org/zvec/Doc.java | 99 ++ .../src/main/java/org/zvec/FieldSchema.java | 60 ++ .../main/java/org/zvec/HnswIndexParams.java | 51 + .../main/java/org/zvec/HnswQueryParams.java | 76 ++ .../src/main/java/org/zvec/TuningProfile.java | 7 + .../src/main/java/org/zvec/VectorQuery.java | 106 ++ .../src/main/java/org/zvec/VectorSchema.java | 127 +++ .../src/main/java/org/zvec/Zvec.java | 136 +++ .../src/main/java/org/zvec/ZvecSchemas.java | 165 +++ .../src/main/java/org/zvec/ZvecSearch.java | 62 ++ .../main/java/org/zvec/crypto/AadEncoder.java | 25 + .../src/main/java/org/zvec/crypto/Aead.java | 10 + .../main/java/org/zvec/crypto/AesGcm256.java | 58 + .../crypto/AuthenticationFailedException.java | 13 + .../org/zvec/crypto/DecryptingProjector.java | 98 ++ .../org/zvec/crypto/DecryptionException.java | 13 + .../crypto/EncryptedCollectionException.java | 9 + .../java/org/zvec/crypto/EncryptedSchema.java | 72 ++ .../org/zvec/crypto/EncryptingInsertor.java | 83 ++ .../crypto/EncryptionConfigException.java | 13 + .../org/zvec/crypto/EncryptionException.java | 13 + .../crypto/EncryptionFailedException.java | 9 + .../org/zvec/crypto/EncryptionMetadata.java | 82 ++ .../crypto/EncryptionMetadataIOException.java | 9 + .../EncryptionMetadataMismatchException.java | 9 + .../crypto/EncryptionRuntimeException.java | 13 + .../java/org/zvec/crypto/EncryptionSpec.java | 76 ++ .../main/java/org/zvec/crypto/Envelope.java | 96 ++ .../java/org/zvec/crypto/EnvelopeCodec.java | 72 ++ .../zvec/crypto/EnvelopeFormatException.java | 9 + 
.../org/zvec/crypto/FilterFieldScanner.java | 56 + .../java/org/zvec/crypto/KeyProvider.java | 23 + .../zvec/crypto/KeyResolutionException.java | 13 + .../java/org/zvec/crypto/SidecarJson.java | 212 ++++ .../java/org/zvec/crypto/SidecarMetadata.java | 48 + .../org/zvec/crypto/SingletonKeyProvider.java | 29 + .../crypto/UnsupportedFieldTypeException.java | 9 + .../java/org/zvec/internal/HnswDefaults.java | 135 +++ .../java/org/zvec/internal/NativeBackend.java | 32 + .../zvec/internal/NativeBackendProvider.java | 6 + .../org/zvec/internal/NativeBackends.java | 72 ++ .../java/org/zvec/internal/NativeHandle.java | 3 + .../org/zvec/internal/NativeOpenResult.java | 44 + .../zvec/internal/SchemaMetadataStore.java | 132 +++ .../java/org/zvec/internal/ZvecException.java | 14 + .../perf/CollectionConcurrentStressMain.java | 492 +++++++++ .../org/zvec/perf/CollectionStressMain.java | 438 ++++++++ .../main/java/org/zvec/perf/LatencyStats.java | 97 ++ .../src/main/java/org/zvec/perf/PerfData.java | 106 ++ .../java/org/zvec/perf/StressOptions.java | 334 ++++++ java/zvec-java/zvec-java-ffm/pom.xml | 104 ++ .../org/zvec/internal/ffm/FfmCollections.java | 225 ++++ .../java/org/zvec/internal/ffm/FfmDocs.java | 428 ++++++++ .../java/org/zvec/internal/ffm/FfmHandle.java | 11 + .../java/org/zvec/internal/ffm/FfmNative.java | 511 +++++++++ .../zvec/internal/ffm/FfmNativeBackend.java | 67 ++ .../ffm/FfmNativeBackendProvider.java | 11 + .../zvec/internal/ffm/FfmNativeLoader.java | 115 ++ .../org/zvec/internal/ffm/FfmQueries.java | 169 +++ .../org/zvec/internal/ffm/FfmSchemas.java | 257 +++++ .../org.zvec.internal.NativeBackendProvider | 1 + java/zvec-java/zvec-java-jni/pom.xml | 81 ++ .../java/org/zvec/internal/jni/JniHandle.java | 41 + .../java/org/zvec/internal/jni/JniNative.java | 36 + .../zvec/internal/jni/JniNativeBackend.java | 89 ++ .../jni/JniNativeBackendProvider.java | 11 + .../zvec/internal/jni/JniNativeLoader.java | 123 +++ .../src/main/native/zvec_java_jni.cc | 991 
++++++++++++++++++ .../org.zvec.internal.NativeBackendProvider | 1 + src/binding/c/c_api.cc | 3 +- src/include/zvec/c_api.h | 2 +- src/include/zvec/db/schema.h | 4 +- 77 files changed, 7632 insertions(+), 5 deletions(-) create mode 100644 java/zvec-java/pom.xml create mode 100644 java/zvec-java/zvec-java-api/pom.xml create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/Collection.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/CollectionSchema.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/DataType.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/Doc.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/FieldSchema.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/HnswIndexParams.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/HnswQueryParams.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/TuningProfile.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/VectorQuery.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/VectorSchema.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/Zvec.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/ZvecSchemas.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/ZvecSearch.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AadEncoder.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/Aead.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AesGcm256.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AuthenticationFailedException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/DecryptingProjector.java create mode 100644 
java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/DecryptionException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptedCollectionException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptedSchema.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptingInsertor.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionConfigException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionFailedException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadata.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadataIOException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadataMismatchException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionRuntimeException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionSpec.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/Envelope.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EnvelopeCodec.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EnvelopeFormatException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/FilterFieldScanner.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/KeyProvider.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/KeyResolutionException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SidecarJson.java create mode 100644 
java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SidecarMetadata.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SingletonKeyProvider.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/UnsupportedFieldTypeException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/HnswDefaults.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackend.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackendProvider.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackends.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeHandle.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeOpenResult.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/SchemaMetadataStore.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/ZvecException.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/CollectionConcurrentStressMain.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/CollectionStressMain.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/LatencyStats.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/PerfData.java create mode 100644 java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/StressOptions.java create mode 100644 java/zvec-java/zvec-java-ffm/pom.xml create mode 100644 java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmCollections.java create mode 100644 java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmDocs.java create mode 100644 java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmHandle.java create mode 100644 
java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmNative.java create mode 100644 java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmNativeBackend.java create mode 100644 java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmNativeBackendProvider.java create mode 100644 java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmNativeLoader.java create mode 100644 java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmQueries.java create mode 100644 java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmSchemas.java create mode 100644 java/zvec-java/zvec-java-ffm/src/main/resources/META-INF/services/org.zvec.internal.NativeBackendProvider create mode 100644 java/zvec-java/zvec-java-jni/pom.xml create mode 100644 java/zvec-java/zvec-java-jni/src/main/java/org/zvec/internal/jni/JniHandle.java create mode 100644 java/zvec-java/zvec-java-jni/src/main/java/org/zvec/internal/jni/JniNative.java create mode 100644 java/zvec-java/zvec-java-jni/src/main/java/org/zvec/internal/jni/JniNativeBackend.java create mode 100644 java/zvec-java/zvec-java-jni/src/main/java/org/zvec/internal/jni/JniNativeBackendProvider.java create mode 100644 java/zvec-java/zvec-java-jni/src/main/java/org/zvec/internal/jni/JniNativeLoader.java create mode 100644 java/zvec-java/zvec-java-jni/src/main/native/zvec_java_jni.cc create mode 100644 java/zvec-java/zvec-java-jni/src/main/resources/META-INF/services/org.zvec.internal.NativeBackendProvider diff --git a/java/zvec-java/pom.xml b/java/zvec-java/pom.xml new file mode 100644 index 000000000..711b43b36 --- /dev/null +++ b/java/zvec-java/pom.xml @@ -0,0 +1,73 @@ + + 4.0.0 + org.zvec + zvec-java-parent + 0.0.1-SNAPSHOT + pom + zvec-java-parent + + + zvec-java-api + zvec-java-jni + zvec-java-ffm + + + + UTF-8 + 1.37 + 5.12.2 + 3.14.0 + 3.5.0 + 3.5.2 + 3.5.0 + host + + + + + + org.junit.jupiter + junit-jupiter + ${junit.version} + + + org.openjdk.jmh + jmh-core + 
${jmh.version} + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + ${maven.compiler.plugin.version} + + + org.apache.maven.plugins + maven-surefire-plugin + ${maven.surefire.plugin.version} + + false + + + + org.codehaus.mojo + exec-maven-plugin + ${exec.maven.plugin.version} + + + org.apache.maven.plugins + maven-enforcer-plugin + ${maven.enforcer.plugin.version} + + + + + + diff --git a/java/zvec-java/zvec-java-api/pom.xml b/java/zvec-java/zvec-java-api/pom.xml new file mode 100644 index 000000000..aea9c6e47 --- /dev/null +++ b/java/zvec-java/zvec-java-api/pom.xml @@ -0,0 +1,71 @@ + + 4.0.0 + + + org.zvec + zvec-java-parent + 0.0.1-SNAPSHOT + + + zvec-java-api + zvec-java-api + + + 11 + + + + + org.junit.jupiter + junit-jupiter + test + + + org.openjdk.jmh + jmh-core + test + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.4.2 + + + test-jar + process-test-classes + + test-jar + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + ${maven.compiler.release} + + + org.openjdk.jmh + jmh-generator-annprocess + ${jmh.version} + + + + + org.openjdk.jmh + jmh-generator-annprocess + ${jmh.version} + + + + + + + diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/Collection.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/Collection.java new file mode 100644 index 000000000..61594789e --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/Collection.java @@ -0,0 +1,125 @@ +package org.zvec; + +import java.util.List; +import java.util.Objects; +import org.zvec.crypto.DecryptingProjector; +import org.zvec.crypto.EncryptedSchema; +import org.zvec.crypto.EncryptingInsertor; +import org.zvec.crypto.EncryptionMetadata; +import org.zvec.crypto.EncryptionSpec; +import org.zvec.crypto.SidecarMetadata; +import org.zvec.internal.NativeBackend; +import org.zvec.internal.NativeHandle; + +public final class Collection implements AutoCloseable { + private final NativeBackend backend; + private final 
NativeHandle handle; + private final CollectionSchema schema; + private final CollectionSchema querySchema; + private final String collectionPath; + private boolean closed; + private EncryptedSchema encryptedSchema = EncryptedSchema.NONE; + + Collection( + NativeBackend backend, + NativeHandle handle, + CollectionSchema schema, + CollectionSchema querySchema, + String collectionPath) { + this.backend = Objects.requireNonNull(backend, "backend"); + this.handle = Objects.requireNonNull(handle, "handle"); + this.schema = Objects.requireNonNull(schema, "schema"); + this.querySchema = Objects.requireNonNull(querySchema, "querySchema"); + this.collectionPath = Objects.requireNonNull(collectionPath, "collectionPath"); + } + + void attachEncryption(EncryptedSchema es) { + this.encryptedSchema = Objects.requireNonNull(es, "encryptedSchema"); + } + + public EncryptedSchema encryptedSchema() { return encryptedSchema; } + + public CollectionSchema schema() { + return schema; + } + + public void flush() { + requireOpen(); + backend.flush(handle); + } + + public int insert(List docs) { + requireOpen(); + Objects.requireNonNull(docs, "docs"); + List toInsert = docs; + if (encryptedSchema != EncryptedSchema.NONE) { + toInsert = EncryptingInsertor.transform(docs, encryptedSchema); + } + return backend.insert(handle, schema, toInsert); + } + + public List query(VectorQuery query) { + requireOpen(); + Objects.requireNonNull(query, "query"); + if (encryptedSchema != EncryptedSchema.NONE) { + checkFilterDoesNotReferenceEncryptedFields(query); + } + List raw = backend.query(handle, querySchema, schema, query); + if (encryptedSchema != EncryptedSchema.NONE) { + raw = DecryptingProjector.transform(raw, encryptedSchema); + } + return raw; + } + + private void checkFilterDoesNotReferenceEncryptedFields(VectorQuery query) { + String filter = query.filter(); + if (filter == null) return; + java.util.Set referenced = org.zvec.crypto.FilterFieldScanner.referencedFields(filter); + java.util.Set 
encrypted = encryptedSchema.encryptedFieldNames(); + for (String fieldName : referenced) { + if (encrypted.contains(fieldName)) { + throw new IllegalArgumentException( + "filter cannot reference encrypted field '" + fieldName + "'"); + } + } + } + + public void setActiveKeyId(String fieldName, String newKeyId) { + requireOpen(); + Objects.requireNonNull(fieldName, "fieldName"); + Objects.requireNonNull(newKeyId, "newKeyId"); + if (encryptedSchema == EncryptedSchema.NONE + || !encryptedSchema.isEncrypted(fieldName)) { + throw new IllegalArgumentException( + "field '" + fieldName + "' is not declared as encrypted"); + } + EncryptionMetadata current = encryptedSchema.metadata(); + java.util.LinkedHashMap updated = + new java.util.LinkedHashMap<>(current.fields()); + EncryptionSpec old = updated.get(fieldName); + EncryptionSpec next = new EncryptionSpec( + old.alg(), newKeyId, old.createdAt(), java.time.Instant.now()); + updated.put(fieldName, next); + EncryptionMetadata fresh = new EncryptionMetadata( + current.version(), current.collectionName(), updated); + + java.nio.file.Path dir = java.nio.file.Paths.get(collectionPath); + SidecarMetadata.write(dir, fresh); + encryptedSchema = EncryptedSchema.reconcile(schema, fresh, encryptedSchema.keyProvider()); + } + + @Override + public void close() { + if (closed) { + return; + } + backend.close(handle); + closed = true; + } + + private void requireOpen() { + if (closed) { + throw new IllegalStateException("Collection is already closed"); + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/CollectionSchema.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/CollectionSchema.java new file mode 100644 index 000000000..bf0c5add7 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/CollectionSchema.java @@ -0,0 +1,96 @@ +package org.zvec; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import 
org.zvec.crypto.EncryptionMetadata; +import org.zvec.crypto.KeyProvider; + +public final class CollectionSchema { + private final String name; + private final List fields; + private final List vectors; + private final Map fieldByName; + private final Map vectorByName; + private final EncryptionMetadata encryption; + private final Map embeddedKeyProviders; + + public CollectionSchema(String name, List fields, List vectors) { + this(name, fields, vectors, null); + } + + public CollectionSchema( + String name, + List fields, + List vectors, + EncryptionMetadata encryption) { + this(name, fields, vectors, encryption, null); + } + + public CollectionSchema( + String name, + List fields, + List vectors, + EncryptionMetadata encryption, + Map embeddedKeyProviders) { + this.name = Objects.requireNonNull(name, "name"); + this.fields = List.copyOf(Objects.requireNonNull(fields, "fields")); + this.vectors = List.copyOf(Objects.requireNonNull(vectors, "vectors")); + this.encryption = encryption; + this.embeddedKeyProviders = embeddedKeyProviders == null ? 
null : Map.copyOf(embeddedKeyProviders); + + Map fieldIndex = new HashMap<>(); + for (FieldSchema field : this.fields) { + FieldSchema previous = + fieldIndex.put(Objects.requireNonNull(field, "field").name(), field); + if (previous != null) { + throw new IllegalArgumentException("Duplicate field name: " + field.name()); + } + } + + Map vectorIndex = new HashMap<>(); + for (VectorSchema vector : this.vectors) { + VectorSchema previous = + vectorIndex.put(Objects.requireNonNull(vector, "vector").name(), vector); + if (previous != null) { + throw new IllegalArgumentException("Duplicate vector name: " + vector.name()); + } + if (fieldIndex.containsKey(vector.name())) { + throw new IllegalArgumentException("Duplicate schema field name: " + vector.name()); + } + } + + this.fieldByName = Map.copyOf(fieldIndex); + this.vectorByName = Map.copyOf(vectorIndex); + } + + public String name() { + return name; + } + + public List fields() { + return fields; + } + + public List vectors() { + return vectors; + } + + public FieldSchema field(String name) { + return fieldByName.get(name); + } + + public VectorSchema vector(String name) { + return vectorByName.get(name); + } + + public Optional encryption() { + return Optional.ofNullable(encryption); + } + + public Optional> embeddedKeyProviders() { + return Optional.ofNullable(embeddedKeyProviders); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/DataType.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/DataType.java new file mode 100644 index 000000000..775708b16 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/DataType.java @@ -0,0 +1,25 @@ +package org.zvec; + +public enum DataType { + STRING(2, false), + BOOL(3, false), + INT64(5, false), + DOUBLE(9, false), + VECTOR_FP32(23, true); + + private final int code; + private final boolean vector; + + DataType(int code, boolean vector) { + this.code = code; + this.vector = vector; + } + + public int code() { + return code; + } + 
+ public boolean isVector() { + return vector; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/Doc.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/Doc.java new file mode 100644 index 000000000..d67ce0cc5 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/Doc.java @@ -0,0 +1,99 @@ +package org.zvec; + +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public final class Doc { + private final String id; + private Double score; + private final LinkedHashMap fields = new LinkedHashMap<>(); + private final LinkedHashMap vectors = new LinkedHashMap<>(); + private final LinkedHashSet nullFields = new LinkedHashSet<>(); + + private Doc(String id) { + this.id = Objects.requireNonNull(id, "id"); + } + + public static Doc of(String id) { + return new Doc(id); + } + + public static Doc result(String id, double score) { + Doc doc = new Doc(id); + doc.score = score; + return doc; + } + + public Doc field(String name, String value) { + String fieldName = Objects.requireNonNull(name, "name"); + fields.put(fieldName, Objects.requireNonNull(value, "value")); + vectors.remove(fieldName); + nullFields.remove(fieldName); + return this; + } + + public Doc field(String name, boolean value) { + String fieldName = Objects.requireNonNull(name, "name"); + fields.put(fieldName, value); + vectors.remove(fieldName); + nullFields.remove(fieldName); + return this; + } + + public Doc field(String name, long value) { + String fieldName = Objects.requireNonNull(name, "name"); + fields.put(fieldName, value); + vectors.remove(fieldName); + nullFields.remove(fieldName); + return this; + } + + public Doc field(String name, double value) { + String fieldName = Objects.requireNonNull(name, "name"); + fields.put(fieldName, value); + vectors.remove(fieldName); + nullFields.remove(fieldName); + return this; + } + + public Doc 
vector(String name, float[] values) { + String vectorName = Objects.requireNonNull(name, "name"); + Objects.requireNonNull(values, "values"); + fields.remove(vectorName); + vectors.put(vectorName, values.clone()); + nullFields.remove(vectorName); + return this; + } + + public Doc nullField(String name) { + String fieldName = Objects.requireNonNull(name, "name"); + fields.remove(fieldName); + vectors.remove(fieldName); + nullFields.add(fieldName); + return this; + } + + public String id() { + return id; + } + + public Double score() { + return score; + } + + public Map fields() { + return Collections.unmodifiableMap(fields); + } + + public Map vectors() { + return Collections.unmodifiableMap(vectors); + } + + public Set nullFields() { + return Collections.unmodifiableSet(nullFields); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/FieldSchema.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/FieldSchema.java new file mode 100644 index 000000000..c5d3205ed --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/FieldSchema.java @@ -0,0 +1,60 @@ +package org.zvec; + +import java.util.Objects; + +public final class FieldSchema { + private final String name; + private final DataType dataType; + private final boolean nullable; + + public FieldSchema(String name, DataType dataType, boolean nullable) { + this.name = Objects.requireNonNull(name, "name"); + this.dataType = Objects.requireNonNull(dataType, "dataType"); + this.nullable = nullable; + if (dataType.isVector()) { + throw new IllegalArgumentException("FieldSchema requires a scalar data type"); + } + } + + public String name() { + return name; + } + + public DataType dataType() { + return dataType; + } + + public boolean nullable() { + return nullable; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof FieldSchema)) { + return false; + } + FieldSchema other = (FieldSchema) obj; + return nullable == 
other.nullable + && name.equals(other.name) + && dataType == other.dataType; + } + + @Override + public int hashCode() { + return Objects.hash(name, dataType, nullable); + } + + @Override + public String toString() { + return "FieldSchema[name=" + + name + + ", dataType=" + + dataType + + ", nullable=" + + nullable + + "]"; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/HnswIndexParams.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/HnswIndexParams.java new file mode 100644 index 000000000..ff8ed7cf9 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/HnswIndexParams.java @@ -0,0 +1,51 @@ +package org.zvec; + +import java.util.Objects; + +public final class HnswIndexParams { + private final int m; + private final int efConstruction; + + public HnswIndexParams(int m, int efConstruction) { + requirePositive(m, "m"); + requirePositive(efConstruction, "efConstruction"); + this.m = m; + this.efConstruction = efConstruction; + } + + public int m() { + return m; + } + + public int efConstruction() { + return efConstruction; + } + + private static void requirePositive(int value, String name) { + if (value <= 0) { + throw new IllegalArgumentException(name + " must be > 0"); + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof HnswIndexParams)) { + return false; + } + HnswIndexParams other = (HnswIndexParams) obj; + return m == other.m && efConstruction == other.efConstruction; + } + + @Override + public int hashCode() { + return Objects.hash(m, efConstruction); + } + + @Override + public String toString() { + return "HnswIndexParams[m=" + m + ", efConstruction=" + efConstruction + "]"; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/HnswQueryParams.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/HnswQueryParams.java new file mode 100644 index 000000000..eb86e2204 --- /dev/null +++ 
b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/HnswQueryParams.java @@ -0,0 +1,76 @@ +package org.zvec; + +import java.util.Objects; + +public final class HnswQueryParams { + private final int ef; + private final float radius; + private final boolean linear; + private final boolean usingRefiner; + + public HnswQueryParams(int ef, float radius, boolean linear, boolean usingRefiner) { + requirePositive(ef, "ef"); + if (radius < 0.0f) { + throw new IllegalArgumentException("radius must be >= 0"); + } + this.ef = ef; + this.radius = radius; + this.linear = linear; + this.usingRefiner = usingRefiner; + } + + public int ef() { + return ef; + } + + public float radius() { + return radius; + } + + public boolean linear() { + return linear; + } + + public boolean usingRefiner() { + return usingRefiner; + } + + private static void requirePositive(int value, String name) { + if (value <= 0) { + throw new IllegalArgumentException(name + " must be > 0"); + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof HnswQueryParams)) { + return false; + } + HnswQueryParams other = (HnswQueryParams) obj; + return ef == other.ef + && Float.compare(radius, other.radius) == 0 + && linear == other.linear + && usingRefiner == other.usingRefiner; + } + + @Override + public int hashCode() { + return Objects.hash(ef, radius, linear, usingRefiner); + } + + @Override + public String toString() { + return "HnswQueryParams[ef=" + + ef + + ", radius=" + + radius + + ", linear=" + + linear + + ", usingRefiner=" + + usingRefiner + + "]"; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/TuningProfile.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/TuningProfile.java new file mode 100644 index 000000000..673baeda8 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/TuningProfile.java @@ -0,0 +1,7 @@ +package org.zvec; + +public enum TuningProfile { + FAST, + BALANCED, + 
ACCURATE +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/VectorQuery.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/VectorQuery.java new file mode 100644 index 000000000..3f7c61708 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/VectorQuery.java @@ -0,0 +1,106 @@ +package org.zvec; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** Mutable, fluent description of one vector search. The target field and query vector are fixed at creation; topK (default 10), filter, projection, includeVector and HNSW tuning are set through chained calls that return this instance. */ public final class VectorQuery { + private final String fieldName; + private final float[] queryVector; + private int topK = 10; + private boolean includeVector; + private String filter; + private boolean outputFieldsSpecified; + private List outputFields = List.of(); + private HnswQueryParams hnswQueryParams; + private TuningProfile tuningProfile; + + /** Stores a private copy of queryVector; rejects a null field name, a null vector and an empty vector. */ private VectorQuery(String fieldName, float[] queryVector) { + this.fieldName = Objects.requireNonNull(fieldName, "fieldName"); + this.queryVector = Objects.requireNonNull(queryVector, "queryVector").clone(); + if (this.queryVector.length == 0) { + throw new IllegalArgumentException("queryVector must not be empty"); + } + } + + /** Static factory; the only way to obtain an instance because the constructor is private. */ public static VectorQuery of(String fieldName, float[] queryVector) { + return new VectorQuery(fieldName, queryVector); + } + + /** Sets the result limit; throws IllegalArgumentException unless topK is positive. */ public VectorQuery topK(int topK) { + if (topK <= 0) { + throw new IllegalArgumentException("topK must be positive"); + } + this.topK = topK; + return this; + } + + /** Replaces the projection with an immutable copy of the given names (null names are rejected) and sets outputFieldsSpecified, even when called with no names. */ public VectorQuery outputFields(String...
fields) { + Objects.requireNonNull(fields, "fields"); + List output = new ArrayList<>(fields.length); + for (String field : fields) { + output.add(Objects.requireNonNull(field, "field")); + } + this.outputFieldsSpecified = true; + this.outputFields = List.copyOf(output); + return this; + } + + public VectorQuery includeVector(boolean includeVector) { + this.includeVector = includeVector; + return this; + } + + /** Stores the filter expression as given; null is accepted and no validation happens here. */ public VectorQuery filter(String filter) { + this.filter = filter; + return this; + } + + /** Sets explicit HNSW query parameters and clears any tuning profile; the two settings are mutually exclusive. */ public VectorQuery hnsw(HnswQueryParams params) { + this.hnswQueryParams = Objects.requireNonNull(params, "params"); + this.tuningProfile = null; + return this; + } + + /** Sets a tuning profile and clears any explicit HNSW query parameters; the two settings are mutually exclusive. */ public VectorQuery withTuningProfile(TuningProfile profile) { + this.tuningProfile = Objects.requireNonNull(profile, "profile"); + this.hnswQueryParams = null; + return this; + } + + public String fieldName() { + return fieldName; + } + + /** Returns a fresh copy of the query vector on every call, so callers cannot mutate the stored array. */ public float[] queryVector() { + return queryVector.clone(); + } + + public int topK() { + return topK; + } + + public boolean includeVector() { + return includeVector; + } + + /** May return null when no filter was set. */ public String filter() { + return filter; + } + + /** True once outputFields(...) has been called, regardless of how many names were passed. */ public boolean outputFieldsSpecified() { + return outputFieldsSpecified; + } + + public List outputFields() { + return outputFields; + } + + /** May return null; at most one of hnswQueryParams and tuningProfile is non-null. */ public HnswQueryParams hnswQueryParams() { + return hnswQueryParams; + } + + public TuningProfile tuningProfile() { + return tuningProfile; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/VectorSchema.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/VectorSchema.java new file mode 100644 index 000000000..4f75e7726 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/VectorSchema.java @@ -0,0 +1,127 @@ +package org.zvec; + +import java.util.Objects; + +/** Immutable description of one vector field: name, vector data type, dimension, and the optional HNSW index parameters, tuning profile and expected document count. The with... methods return modified copies. */ public final class VectorSchema { + private final String name; + private final DataType dataType; + private final int dimension; + private final HnswIndexParams hnswIndexParams; + private final TuningProfile
tuningProfile; + private final Long expectedDocCount; + + /** Convenience constructor that leaves HNSW parameters, tuning profile and expected document count unset (null). */ public VectorSchema(String name, DataType dataType, int dimension) { + this(name, dataType, dimension, null, null, null); + } + + /** Full constructor. Requires a non-null name, a data type for which isVector() is true, a dimension greater than 0, and, when given, an expectedDocCount greater than 0. The three optional arguments may be null. */ public VectorSchema( + String name, + DataType dataType, + int dimension, + HnswIndexParams hnswIndexParams, + TuningProfile tuningProfile, + Long expectedDocCount) { + this.name = Objects.requireNonNull(name, "name"); + this.dataType = Objects.requireNonNull(dataType, "dataType"); + if (!dataType.isVector()) { + throw new IllegalArgumentException("VectorSchema requires a vector data type"); + } + if (dimension <= 0) { + throw new IllegalArgumentException("dimension must be greater than 0"); + } + if (expectedDocCount != null && expectedDocCount <= 0L) { + throw new IllegalArgumentException("expectedDocCount must be > 0"); + } + this.dimension = dimension; + this.hnswIndexParams = hnswIndexParams; + this.tuningProfile = tuningProfile; + this.expectedDocCount = expectedDocCount; + } + + public String name() { + return name; + } + + public DataType dataType() { + return dataType; + } + + public int dimension() { + return dimension; + } + + /** May return null. */ public HnswIndexParams hnswIndexParams() { + return hnswIndexParams; + } + + /** May return null. */ public TuningProfile tuningProfile() { + return tuningProfile; + } + + /** May return null when no expected document count was supplied. */ public Long expectedDocCount() { + return expectedDocCount; + } + + /** Returns a copy with the given HNSW parameters; the tuning profile and expected document count of the copy are reset to null. */ public VectorSchema withHnswIndex(HnswIndexParams params) { + return new VectorSchema( + name, dataType, dimension, Objects.requireNonNull(params, "params"), null, null); + } + + /** Returns a copy with the given tuning profile; HNSW parameters and expected document count of the copy are reset to null. */ public VectorSchema withTuningProfile(TuningProfile profile) { + return new VectorSchema( + name, dataType, dimension, null, Objects.requireNonNull(profile, "profile"), null); + } + + /** Returns a copy with the given tuning profile and expected document count; HNSW parameters of the copy are reset to null. */ public VectorSchema withTuningProfile(TuningProfile profile, long expectedDocCount) { + return new VectorSchema( + name, + dataType, + dimension, + null, + Objects.requireNonNull(profile, "profile"), + expectedDocCount); + } + + /** Returns a copy with the given expected document count; unlike the other with... methods this one keeps the current HNSW parameters and tuning profile. */ public VectorSchema withExpectedDocCount(long
expectedDocCount) { + return new VectorSchema(name, dataType, dimension, hnswIndexParams, tuningProfile, expectedDocCount); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof VectorSchema)) { + return false; + } + VectorSchema other = (VectorSchema) obj; + return dimension == other.dimension + && name.equals(other.name) + && dataType == other.dataType + && Objects.equals(hnswIndexParams, other.hnswIndexParams) + && tuningProfile == other.tuningProfile + && Objects.equals(expectedDocCount, other.expectedDocCount); + } + + @Override + public int hashCode() { + return Objects.hash(name, dataType, dimension, hnswIndexParams, tuningProfile, expectedDocCount); + } + + @Override + public String toString() { + return "VectorSchema[name=" + + name + + ", dataType=" + + dataType + + ", dimension=" + + dimension + + ", hnswIndexParams=" + + hnswIndexParams + + ", tuningProfile=" + + tuningProfile + + ", expectedDocCount=" + + expectedDocCount + + "]"; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/Zvec.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/Zvec.java new file mode 100644 index 000000000..cbc3a2770 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/Zvec.java @@ -0,0 +1,136 @@ +package org.zvec; + +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Objects; +import org.zvec.crypto.EncryptedSchema; +import org.zvec.crypto.EncryptionMetadata; +import org.zvec.crypto.KeyProvider; +import org.zvec.crypto.SidecarMetadata; +import org.zvec.internal.NativeBackend; +import org.zvec.internal.NativeBackends; +import org.zvec.internal.NativeOpenResult; +import org.zvec.internal.SchemaMetadataStore; + +/** Static entry points for creating and opening collections through the backend returned by NativeBackends.backend(). Not instantiable. */ public final class Zvec { + private Zvec() {} + + /** Delegates to ensureInitialized() of the selected native backend. */ public static void ensureInitialized() { + NativeBackends.backend().ensureInitialized(); + } + + /** Creates and opens a collection. When the schema declares encrypted fields, every such field must have an embedded key provider, otherwise EncryptedCollectionException is thrown; the embedded providers are then combined and the three-argument overload is used. Schemas without encrypted fields go straight to the backend and no encryption sidecar is written on this path. NOTE(review): ensureInitialized() runs before the null checks here, whereas openWithKeys checks its arguments first. */ public static Collection createAndOpen(String path, CollectionSchema schema) {
+ ensureInitialized(); + Objects.requireNonNull(path, "path"); + Objects.requireNonNull(schema, "schema"); + + EncryptionMetadata meta = schema.encryption().orElse(null); + if (meta != null && !meta.encryptedFieldNames().isEmpty()) { + java.util.Map embedded = + schema.embeddedKeyProviders().orElse(java.util.Map.of()); + if (!embedded.keySet().containsAll(meta.encryptedFieldNames())) { + throw new org.zvec.crypto.EncryptedCollectionException( + "schema declares encrypted fields without embedded keys; " + "use Zvec.createAndOpen(path, schema, provider)"); + } + KeyProvider composite = compose(embedded); + return createAndOpen(path, schema, composite); + } + return createBackend(path, schema); + } + + /** Builds a provider that asks each embedded provider in map iteration order and returns the first non-null key, or null when none of them yields one. NOTE(review): every provider is queried for every key id, so this presumably relies on resolve returning null rather than throwing for an unknown id; confirm against SingletonKeyProvider. */ private static KeyProvider compose(java.util.Map embedded) { + return keyId -> { + for (KeyProvider p : embedded.values()) { + byte[] k = p.resolve(keyId); + if (k != null) return k; + } + return null; + }; + } + + /** Creates and opens a collection with an explicit key provider. The encryption metadata is reconciled against the schema before the backend is invoked, then the sidecar is written (an empty one when the schema has no encryption metadata) and the encryption state is attached. On a RuntimeException after creation the collection is closed and the exception rethrown. NOTE(review): that failure path does not remove the directory the native open created. */ public static Collection createAndOpen(String path, CollectionSchema schema, KeyProvider provider) { + ensureInitialized(); + Objects.requireNonNull(path, "path"); + Objects.requireNonNull(schema, "schema"); + Objects.requireNonNull(provider, "provider"); + + EncryptionMetadata meta = schema.encryption().orElse(EncryptionMetadata.empty(schema.name())); + + // Reconcile before native open so misconfig fails fast without creating an orphan directory. + EncryptedSchema es = meta.encryptedFieldNames().isEmpty() + ? EncryptedSchema.NONE + : EncryptedSchema.reconcile(schema, meta, provider); + + // Native open creates the collection directory; sidecars are written into it afterward.
+ Collection col = createBackend(path, schema); + try { + SidecarMetadata.write(Paths.get(path), meta); + col.attachEncryption(es); + return col; + } catch (RuntimeException e) { + try { col.close(); } catch (RuntimeException ignored) {} + throw e; + } + } + + /** Opens an existing collection without keys. If the sidecar lists any encrypted field the collection is closed again and EncryptedCollectionException is thrown. */ public static Collection open(String path) { + ensureInitialized(); + Collection col = openBackend(Objects.requireNonNull(path, "path")); + try { + java.util.Optional meta = + SidecarMetadata.read(java.nio.file.Paths.get(path)); + if (meta.isPresent() && !meta.get().encryptedFieldNames().isEmpty()) { + throw new org.zvec.crypto.EncryptedCollectionException( + "collection at '" + path + "' has encrypted fields; use Zvec.openWithKeys"); + } + return col; + } catch (RuntimeException e) { + try { col.close(); } catch (RuntimeException ignored) {} + throw e; + } + } + + /** Opens an existing collection and, when the sidecar lists encrypted fields, reconciles it with the opened schema and attaches the resulting encryption state. A collection with no sidecar or no encrypted fields is returned as is. The collection is closed if any RuntimeException occurs after the backend open. */ public static Collection openWithKeys(String path, KeyProvider provider) { + Objects.requireNonNull(path, "path"); + Objects.requireNonNull(provider, "provider"); + ensureInitialized(); + Collection col = openBackend(path); + try { + java.util.Optional meta = SidecarMetadata.read(java.nio.file.Paths.get(path)); + if (meta.isPresent() && !meta.get().encryptedFieldNames().isEmpty()) { + EncryptedSchema es = EncryptedSchema.reconcile(col.schema(), meta.get(), provider); + col.attachEncryption(es); + } + return col; + } catch (RuntimeException e) { + try { col.close(); } catch (RuntimeException ignored) {} + throw e; + } + } + + /** Opens the native collection, merges the schema reported by the backend with SchemaMetadataStore, and wraps both in a Collection. The native handle is closed if the merge or wrapping throws. */ private static Collection openBackend(String path) { + NativeBackend backend = NativeBackends.backend(); + NativeOpenResult result = backend.open(path); + try { + CollectionSchema querySchema = result.querySchema(); + CollectionSchema publicSchema = SchemaMetadataStore.merge(path, querySchema); + return new Collection(backend, result.handle(), publicSchema, querySchema, path); + } catch (RuntimeException e) { + backend.close(result.handle()); + throw e; + } + } + + /** Creates the native collection, persists the caller schema through SchemaMetadataStore and wraps the handle in a Collection. The native handle is closed if the write or wrapping throws. */ private static Collection createBackend(String path,
CollectionSchema schema) { + NativeBackend backend = NativeBackends.backend(); + NativeOpenResult result = backend.createAndOpen(path, schema); + try { + SchemaMetadataStore.write(path, schema); + return new Collection(backend, result.handle(), schema, result.querySchema(), path); + } catch (RuntimeException e) { + backend.close(result.handle()); + throw e; + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/ZvecSchemas.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/ZvecSchemas.java new file mode 100644 index 000000000..7fde1e242 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/ZvecSchemas.java @@ -0,0 +1,165 @@ +package org.zvec; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.zvec.crypto.EncryptionMetadata; +import org.zvec.crypto.EncryptionSpec; +import org.zvec.crypto.KeyProvider; +import org.zvec.crypto.SingletonKeyProvider; +import org.zvec.crypto.UnsupportedFieldTypeException; + +/** Entry point of the fluent schema builder; see collection(name). Not instantiable. */ public final class ZvecSchemas { + private ZvecSchemas() {} + + public static Builder collection(String name) { + return new Builder(name); + } + + /** Order-sensitive fluent builder for a CollectionSchema. Modifier calls apply to the most recent declaration: encrypted(...) to the last scalar field, and expectedDocCount(...), fast(), balanced() and accurate() to the last vector. */ public static final class Builder { + private final String name; + private final List fields = new ArrayList<>(); + private final List vectors = new ArrayList<>(); + private int activeVectorIndex = -1; + private String activeFieldName = null; + private DataType activeFieldType = null; + private final LinkedHashMap encryptionByField = new LinkedHashMap<>(); + private final LinkedHashMap embeddedProvidersByField = new LinkedHashMap<>(); + + private Builder(String name) { + this.name = Objects.requireNonNull(name, "name"); + } + + /** Declares a STRING field and makes it the active field for a following encrypted(...) call. */ public Builder string(String name) { + fields.add(new FieldSchema(name, DataType.STRING, false)); + activeFieldName = name; + activeFieldType = DataType.STRING; + return this; + } + + public Builder bool(String name) { +
fields.add(new FieldSchema(name, DataType.BOOL, false)); + activeFieldName = name; + activeFieldType = DataType.BOOL; + return this; + } + + public Builder int64(String name) { + fields.add(new FieldSchema(name, DataType.INT64, false)); + activeFieldName = name; + activeFieldType = DataType.INT64; + return this; + } + + public Builder doubleField(String name) { + fields.add(new FieldSchema(name, DataType.DOUBLE, false)); + activeFieldName = name; + activeFieldType = DataType.DOUBLE; + return this; + } + + /** Declares a VECTOR_FP32 vector, makes it the active vector, and clears the active scalar field so that encrypted(...) cannot follow directly. */ public Builder vector(String name, int dimension) { + vectors.add(new VectorSchema(name, DataType.VECTOR_FP32, dimension)); + activeVectorIndex = vectors.size() - 1; + activeFieldName = null; + activeFieldType = null; + return this; + } + + /** Marks the active field as encrypted under the given key id with algorithm AES-256-GCM, stamped with Instant.now(). Throws IllegalStateException when there is no active field or it is already marked, and UnsupportedFieldTypeException when the active field is not a STRING. */ public Builder encrypted(String keyId) { + Objects.requireNonNull(keyId, "keyId"); + if (activeFieldName == null) { + throw new IllegalStateException("encrypted(...) must follow string(name)"); + } + if (activeFieldType != DataType.STRING) { + throw new UnsupportedFieldTypeException( + "v1 only supports encrypted STRING fields; '" + activeFieldName + + "' has type " + activeFieldType); + } + if (encryptionByField.containsKey(activeFieldName)) { + throw new IllegalStateException( + "field '" + activeFieldName + "' already marked encrypted"); + } + encryptionByField.put(activeFieldName, + new EncryptionSpec("AES-256-GCM", keyId, Instant.now(), null)); + return this; + } + + /** Same as encrypted(keyId) and additionally embeds the key in the schema through a SingletonKeyProvider for the active field. NOTE(review): the key array is passed on as given; whether SingletonKeyProvider copies it is not visible here. */ public Builder encrypted(String keyId, byte[] key) { + Objects.requireNonNull(key, "key"); + encrypted(keyId); // reuses all checks + spec creation + embeddedProvidersByField.put(activeFieldName, new SingletonKeyProvider(keyId, key)); + return this; + } + + public Builder expectedDocCount(long expectedDocCount) { + VectorSchema vector = + requireActiveVector("expectedDocCount(...)
must follow vector(name, dimension)"); + vectors.set(activeVectorIndex, applyExpectedDocCount(vector, expectedDocCount)); + return this; + } + + /** Applies TuningProfile.FAST to the active vector. */ public Builder fast() { + applyProfile("fast() must follow vector(name, dimension)", TuningProfile.FAST); + return this; + } + + /** Applies TuningProfile.BALANCED to the active vector. */ public Builder balanced() { + applyProfile("balanced() must follow vector(name, dimension)", TuningProfile.BALANCED); + return this; + } + + /** Applies TuningProfile.ACCURATE to the active vector. */ public Builder accurate() { + applyProfile("accurate() must follow vector(name, dimension)", TuningProfile.ACCURATE); + return this; + } + + /** Builds the schema. Without encrypted fields the three-argument CollectionSchema constructor is used; otherwise V1 encryption metadata is attached together with a copy of the embedded providers, or null when no keys were embedded. */ public CollectionSchema build() { + if (encryptionByField.isEmpty()) { + return new CollectionSchema(name, fields, vectors); + } + EncryptionMetadata meta = new EncryptionMetadata( + EncryptionMetadata.VERSION_V1, name, encryptionByField); + Map embedded = + embeddedProvidersByField.isEmpty() + ? null + : new LinkedHashMap<>(embeddedProvidersByField); + return new CollectionSchema(name, fields, vectors, meta, embedded); + } + + /** Replaces the active vector with a copy carrying the given profile; errorMessage is used when no vector has been declared yet. */ private void applyProfile(String errorMessage, TuningProfile profile) { + VectorSchema vector = requireActiveVector(errorMessage); + vectors.set(activeVectorIndex, applyProfile(vector, profile)); + } + + /** Returns the most recently declared vector or throws IllegalStateException with the given message when none exists. */ private VectorSchema requireActiveVector(String errorMessage) { + if (activeVectorIndex < 0) { + throw new IllegalStateException(errorMessage); + } + return vectors.get(activeVectorIndex); + } + + /** Copy of the vector with the given profile; the expected document count is kept and HNSW parameters are set to null. */ private static VectorSchema applyProfile(VectorSchema vector, TuningProfile profile) { + return new VectorSchema( + vector.name(), + vector.dataType(), + vector.dimension(), + null, + Objects.requireNonNull(profile, "profile"), + vector.expectedDocCount()); + } + + /** Copy of the vector with the given expected document count; the tuning profile is kept. NOTE(review): HNSW parameters are passed as null here, whereas VectorSchema.withExpectedDocCount keeps them. No builder method sets HNSW parameters, so the difference is currently not observable through this builder; confirm it is intended. */ private static VectorSchema applyExpectedDocCount(VectorSchema vector, long expectedDocCount) { + return new VectorSchema( + vector.name(), + vector.dataType(), + vector.dimension(), + null, + vector.tuningProfile(), + expectedDocCount); + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/ZvecSearch.java
b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/ZvecSearch.java new file mode 100644 index 000000000..4704ee327 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/ZvecSearch.java @@ -0,0 +1,62 @@ +package org.zvec; + +import java.util.Objects; + +/** Entry point of the fluent search builder; see vector(fieldName, queryVector). Not instantiable. */ public final class ZvecSearch { + private ZvecSearch() {} + + /** Starts a builder around VectorQuery.of(fieldName, queryVector), so the validation of that factory applies. */ public static Builder vector(String fieldName, float[] queryVector) { + return new Builder(VectorQuery.of(fieldName, queryVector)); + } + + /** Thin fluent wrapper that forwards every call to one underlying VectorQuery instance. */ public static final class Builder { + private final VectorQuery query; + + private Builder(VectorQuery query) { + this.query = Objects.requireNonNull(query, "query"); + } + + public Builder topK(int topK) { + query.topK(topK); + return this; + } + + public Builder fast() { + applyProfile(TuningProfile.FAST); + return this; + } + + public Builder balanced() { + applyProfile(TuningProfile.BALANCED); + return this; + } + + public Builder accurate() { + applyProfile(TuningProfile.ACCURATE); + return this; + } + + /** Sets the projection by forwarding to VectorQuery.outputFields. */ public Builder project(String... fields) { + query.outputFields(fields); + return this; + } + + public Builder includeVector() { + query.includeVector(true); + return this; + } + + public Builder filter(String filter) { + query.filter(filter); + return this; + } + + /** Returns the wrapped VectorQuery itself, not a copy; later calls on this builder keep mutating the returned object. */ public VectorQuery build() { + return query; + } + + private void applyProfile(TuningProfile profile) { + query.withTuningProfile(profile); + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AadEncoder.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AadEncoder.java new file mode 100644 index 000000000..af922fd94 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AadEncoder.java @@ -0,0 +1,25 @@ +package org.zvec.crypto; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +/** + * Length-prefixed AAD: u32_be(len) || utf8(value) for each of (id, fieldName, collectionName).
+ * Recomputed on both encrypt and decrypt; never stored in the envelope. + */ +public final class AadEncoder { + private AadEncoder() {} + + /** Encodes the three values in the order id, fieldName, collectionName. All arguments must be non-null. The buffer size is 12 bytes for the three 4-byte length prefixes plus the UTF-8 payloads, so the backing array is returned exactly full. */ public static byte[] encode(String id, String fieldName, String collectionName) { + byte[] idBytes = Objects.requireNonNull(id, "id").getBytes(StandardCharsets.UTF_8); + byte[] fieldBytes = Objects.requireNonNull(fieldName, "fieldName").getBytes(StandardCharsets.UTF_8); + byte[] collBytes = Objects.requireNonNull(collectionName, "collectionName").getBytes(StandardCharsets.UTF_8); + + ByteBuffer buf = ByteBuffer.allocate(12 + idBytes.length + fieldBytes.length + collBytes.length); + buf.putInt(idBytes.length).put(idBytes); + buf.putInt(fieldBytes.length).put(fieldBytes); + buf.putInt(collBytes.length).put(collBytes); + return buf.array(); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/Aead.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/Aead.java new file mode 100644 index 000000000..ed3af70cd --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/Aead.java @@ -0,0 +1,10 @@ +package org.zvec.crypto; + +/** Authenticated-encryption abstraction with associated data; AesGcm256 is the implementation shown in this package. */ public interface Aead { + /** Returns the algorithm id used in the envelope alg byte.
*/ + int algId(); + + /** Encrypts plaintext under key and nonce, binding aad; in the AesGcm256 implementation the returned array is the output of a single Cipher.doFinal call. */ byte[] seal(byte[] key, byte[] nonce, byte[] plaintext, byte[] aad); + + /** Reverses seal using the same key, nonce and aad; the AesGcm256 implementation throws AuthenticationFailedException when verification or decryption fails. */ byte[] open(byte[] key, byte[] nonce, byte[] ciphertext, byte[] aad); +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AesGcm256.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AesGcm256.java new file mode 100644 index 000000000..3ab3e3869 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AesGcm256.java @@ -0,0 +1,58 @@ +package org.zvec.crypto; + +import javax.crypto.Cipher; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.SecretKeySpec; +import java.security.GeneralSecurityException; + +/** Aead implementation over the JCA transformation AES/GCM/NoPadding with a 32-byte key, a 12-byte nonce and a 128-bit tag. A new Cipher instance is obtained for every call. */ public final class AesGcm256 implements Aead { + public static final int KEY_LEN = 32; + public static final int NONCE_LEN = 12; + public static final int TAG_BITS = 128; + + @Override + public int algId() { return Envelope.ALG_AES_256_GCM; } + + /** Encrypts plaintext; a null plaintext is treated as an empty array and the AAD is only fed to the cipher when it is non-null and non-empty. Throws IllegalArgumentException for a key or nonce of the wrong length and EncryptionFailedException for any GeneralSecurityException. Nonce uniqueness is the responsibility of the caller. */ @Override + public byte[] seal(byte[] key, byte[] nonce, byte[] plaintext, byte[] aad) { + validate(key, nonce); + try { + Cipher c = Cipher.getInstance("AES/GCM/NoPadding"); + c.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, "AES"), + new GCMParameterSpec(TAG_BITS, nonce)); + if (aad != null && aad.length > 0) { + c.updateAAD(aad); + } + return c.doFinal(plaintext == null ?
new byte[0] : plaintext); + } catch (GeneralSecurityException e) { + throw new EncryptionFailedException("AES-256-GCM seal failed", e); + } + } + + /** Decrypts and verifies ciphertext. A tag mismatch is reported as AuthenticationFailedException. NOTE(review): every other GeneralSecurityException (for example an unavailable algorithm) is also reported as AuthenticationFailedException, and a null ciphertext is not checked before doFinal; confirm both are intended. */ @Override + public byte[] open(byte[] key, byte[] nonce, byte[] ciphertext, byte[] aad) { + validate(key, nonce); + try { + Cipher c = Cipher.getInstance("AES/GCM/NoPadding"); + c.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, "AES"), + new GCMParameterSpec(TAG_BITS, nonce)); + if (aad != null && aad.length > 0) { + c.updateAAD(aad); + } + return c.doFinal(ciphertext); + } catch (javax.crypto.AEADBadTagException e) { + throw new AuthenticationFailedException("GCM tag mismatch", e); + } catch (GeneralSecurityException e) { + throw new AuthenticationFailedException("AES-256-GCM open failed: " + e.getClass().getSimpleName(), e); + } + } + + /** Rejects a null or wrongly sized key or nonce; a null argument is reported with length 0 in the message. */ private static void validate(byte[] key, byte[] nonce) { + if (key == null || key.length != KEY_LEN) { + throw new IllegalArgumentException("AES-256 key must be " + KEY_LEN + " bytes, got " + (key == null ? 0 : key.length)); + } + if (nonce == null || nonce.length != NONCE_LEN) { + throw new IllegalArgumentException("nonce must be " + NONCE_LEN + " bytes, got " + (nonce == null ? 0 : nonce.length)); + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AuthenticationFailedException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AuthenticationFailedException.java new file mode 100644 index 000000000..4f7c66de1 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/AuthenticationFailedException.java @@ -0,0 +1,13 @@ +package org.zvec.crypto; + +/** Thrown when AEAD authentication tag verification fails during decryption.
*/ +public final class AuthenticationFailedException extends DecryptionException { + + public AuthenticationFailedException(String message) { + super(message); + } + + public AuthenticationFailedException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/DecryptingProjector.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/DecryptingProjector.java new file mode 100644 index 000000000..f44fa8777 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/DecryptingProjector.java @@ -0,0 +1,98 @@ +package org.zvec.crypto; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.zvec.Doc; + +/** Read-side counterpart of EncryptingInsertor: rebuilds result documents with the String values of encrypted fields decoded from their Base64 envelope and decrypted. Not instantiable. */ public final class DecryptingProjector { + private static final Aead AEAD = new AesGcm256(); + + private DecryptingProjector() {} + + /** Returns the input list itself when the schema has no encrypted fields; otherwise returns a new list of new Doc objects. For each String value on an encrypted field the envelope is decoded, its alg and payload type are checked, the key is resolved by the key id stored in the envelope (cached for the duration of this call), and the value is opened with the AAD recomputed from doc id, field name and collection name. All other fields, the vectors and the null fields are copied through; the score is preserved when present. */ public static List transform(List docs, EncryptedSchema es) { + if (es == EncryptedSchema.NONE || es.encryptedFieldNames().isEmpty()) { + return docs; + } + String collectionName = es.schema().name(); + Map keyCache = new HashMap<>(); + + List result = new ArrayList<>(docs.size()); + for (Doc input : docs) { + Doc out = input.score() == null ?
Doc.of(input.id()) : Doc.result(input.id(), input.score()); + for (Map.Entry e : input.fields().entrySet()) { + String name = e.getKey(); + Object value = e.getValue(); + if (es.isEncrypted(name) && value instanceof String) { + String b64 = (String) value; + Envelope env = EnvelopeCodec.decodeBase64(b64); + if (env.alg() != AEAD.algId()) { + throw new EnvelopeFormatException( + "unsupported alg=0x" + Integer.toHexString(env.alg()) + + " (field='" + name + "' doc.id='" + input.id() + "')"); + } + if (env.payloadType() != Envelope.PAYLOAD_STRING) { + throw new EnvelopeFormatException( + "unsupported payload_type=0x" + Integer.toHexString(env.payloadType()) + + " (field='" + name + "' doc.id='" + input.id() + "')"); + } + byte[] key = keyCache.computeIfAbsent(env.keyId(), kid -> resolveKey(es, kid, input.id(), name)); + byte[] aad = AadEncoder.encode(input.id(), name, collectionName); + byte[] pt; + try { + pt = AEAD.open(key, env.nonce(), env.ciphertext(), aad); + } catch (AuthenticationFailedException afe) { + /* Rethrown with field, doc id and key id added; the original exception is kept as the cause. */ throw new AuthenticationFailedException( + "GCM tag mismatch decrypting field='" + name + + "' doc.id='" + input.id() + + "' keyId='" + env.keyId() + "'", afe); + } + out.field(name, new String(pt, StandardCharsets.UTF_8)); + } else { + assignField(out, name, value); + } + } + for (Map.Entry v : input.vectors().entrySet()) { + out.vector(v.getKey(), v.getValue()); + } + for (String n : input.nullFields()) { + out.nullField(n); + } + result.add(out); + } + return result; + } + + /** Resolves one key through the schema key provider and converts every failure into KeyResolutionException: an exception thrown by the provider, a null result, or a key that is not AesGcm256.KEY_LEN bytes long. docId and fieldName only enrich the messages. */ private static byte[] resolveKey(EncryptedSchema es, String keyId, String docId, String fieldName) { + byte[] k; + try { + k = es.keyProvider().resolve(keyId); + } catch (RuntimeException t) { + throw new KeyResolutionException( + "KeyProvider.resolve threw for keyId='" + keyId + + "' (field='" + fieldName + "' doc.id='" + docId + "')", t); + } + if (k == null) { + throw new KeyResolutionException( + "KeyProvider returned null for keyId='" + keyId + + "' (field='" + fieldName + "'
doc.id='" + docId + "')"); + } + if (k.length != AesGcm256.KEY_LEN) { + throw new KeyResolutionException( + "KeyProvider returned " + k.length + "-byte key, expected " + + AesGcm256.KEY_LEN + " bytes for keyId='" + keyId + "'"); + } + return k; + } + + /** Copies one untyped field value onto the output document through the matching typed Doc.field overload (String, Boolean, Long, Double); null becomes a null field and any other runtime type raises IllegalStateException. */ private static void assignField(Doc out, String name, Object value) { + if (value instanceof String) out.field(name, (String) value); + else if (value instanceof Boolean) out.field(name, ((Boolean) value).booleanValue()); + else if (value instanceof Long) out.field(name, ((Long) value).longValue()); + else if (value instanceof Double) out.field(name, ((Double) value).doubleValue()); + else if (value == null) out.nullField(name); + else throw new IllegalStateException("unexpected field value type: " + value.getClass()); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/DecryptionException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/DecryptionException.java new file mode 100644 index 000000000..4fbf16d94 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/DecryptionException.java @@ -0,0 +1,13 @@ +package org.zvec.crypto; + +/** Thrown when field decryption fails during a read operation. Non-final: AuthenticationFailedException extends it. */ +public class DecryptionException extends EncryptionRuntimeException { + + public DecryptionException(String message) { + super(message); + } + + public DecryptionException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptedCollectionException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptedCollectionException.java new file mode 100644 index 000000000..af1c6d7b1 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptedCollectionException.java @@ -0,0 +1,9 @@ +package org.zvec.crypto; + +/** Thrown when a collection-level encryption constraint is violated.
*/ +public final class EncryptedCollectionException extends EncryptionConfigException { + + public EncryptedCollectionException(String message) { + super(message); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptedSchema.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptedSchema.java new file mode 100644 index 000000000..2c093e46b --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptedSchema.java @@ -0,0 +1,72 @@ +package org.zvec.crypto; + +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import org.zvec.CollectionSchema; +import org.zvec.DataType; +import org.zvec.FieldSchema; + +/** Bundle of CollectionSchema + EncryptionMetadata + KeyProvider. Constructed via reconcile. */ +public final class EncryptedSchema { + + /** Sentinel: collection has no encrypted fields. Its schema() and keyProvider() are null, so callers must test for NONE before using them. */ + public static final EncryptedSchema NONE = + new EncryptedSchema(null, EncryptionMetadata.empty("__none__"), null); + + private final CollectionSchema schema; + private final EncryptionMetadata metadata; + private final KeyProvider keyProvider; + + private EncryptedSchema(CollectionSchema schema, EncryptionMetadata metadata, KeyProvider keyProvider) { + this.schema = schema; + this.metadata = metadata; + this.keyProvider = keyProvider; + } + + /** Validates metadata against schema and bundles them with the key provider. Returns NONE when the metadata lists no encrypted fields; in that case keyProvider may be null. Otherwise the key provider must be non-null, the collection names must match, and every encrypted field must exist in the schema with type STRING; a violation raises EncryptionMetadataMismatchException. No key is resolved here. */ public static EncryptedSchema reconcile( + CollectionSchema schema, EncryptionMetadata metadata, KeyProvider keyProvider) { + Objects.requireNonNull(schema, "schema"); + Objects.requireNonNull(metadata, "metadata"); + + if (metadata.encryptedFieldNames().isEmpty()) { + return NONE; + } + Objects.requireNonNull(keyProvider, "keyProvider"); + + if (!schema.name().equals(metadata.collectionName())) { + throw new EncryptionMetadataMismatchException( + "collection name mismatch: schema='" + schema.name() + + "' sidecar='" + metadata.collectionName() + "'"); + } + for (Map.Entry e : metadata.fields().entrySet()) { + String name = e.getKey(); +
FieldSchema field = schema.field(name); + if (field == null) { + throw new EncryptionMetadataMismatchException( + "encrypted field '" + name + "' not present in schema"); + } + if (field.dataType() != DataType.STRING) { + throw new EncryptionMetadataMismatchException( + "encrypted field '" + name + "' has type " + field.dataType() + + " but v1 only supports STRING"); + } + } + return new EncryptedSchema(schema, metadata, keyProvider); + } + + public boolean isEncrypted(String fieldName) { + return metadata.isEncrypted(fieldName); + } + + public Set encryptedFieldNames() { return metadata.encryptedFieldNames(); } + + /** Returns the active key id recorded for the field, or null when the field has no encryption spec. */ public String activeKeyId(String fieldName) { + EncryptionSpec spec = metadata.spec(fieldName); + return spec == null ? null : spec.activeKeyId(); + } + + /** Null on the NONE sentinel. */ public CollectionSchema schema() { return schema; } + public EncryptionMetadata metadata() { return metadata; } + /** Null on the NONE sentinel. */ public KeyProvider keyProvider() { return keyProvider; } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptingInsertor.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptingInsertor.java new file mode 100644 index 000000000..d6d64feef --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptingInsertor.java @@ -0,0 +1,83 @@ +package org.zvec.crypto; + +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.zvec.Doc; + +/** Write-side counterpart of DecryptingProjector: rebuilds documents with the String values of encrypted fields replaced by a Base64 envelope. Not instantiable. */ public final class EncryptingInsertor { + private static final SecureRandom RNG = new SecureRandom(); + private static final Aead AEAD = new AesGcm256(); + + private EncryptingInsertor() {} + + /** Returns the input list itself when the schema has no encrypted fields; otherwise returns a new list of new Doc objects created with Doc.of(id), so a score on the input is not carried over. Each String value on an encrypted field is sealed under the active key id of that field (keys are cached for the duration of this call) with a fresh random nonce of Envelope.NONCE_LEN bytes and an AAD built from doc id, field name and collection name. NOTE(review): a non-String value on an encrypted field takes the else branch and is copied through in plaintext; reconcile restricts encrypted fields to STRING, so this presumably cannot happen for schema-conforming documents, but confirm. */ public static List transform(List docs, EncryptedSchema es) { + if (es == EncryptedSchema.NONE || es.encryptedFieldNames().isEmpty()) { + return docs; + } + String collectionName = es.schema().name(); + Map keyCache = new HashMap<>(); + + List result = new
ArrayList<>(docs.size()); + for (Doc input : docs) { + Doc out = Doc.of(input.id()); + for (Map.Entry e : input.fields().entrySet()) { + String name = e.getKey(); + Object value = e.getValue(); + if (es.isEncrypted(name) && value instanceof String) { + String pt = (String) value; + String keyId = es.activeKeyId(name); + byte[] key = keyCache.computeIfAbsent(keyId, kid -> resolveKey(es, kid)); + byte[] nonce = new byte[Envelope.NONCE_LEN]; + RNG.nextBytes(nonce); + byte[] aad = AadEncoder.encode(input.id(), name, collectionName); + byte[] ct = AEAD.seal(key, nonce, pt.getBytes(StandardCharsets.UTF_8), aad); + Envelope env = new Envelope( + Envelope.VERSION_V1, AEAD.algId(), Envelope.PAYLOAD_STRING, + keyId, nonce, ct); + out.field(name, EnvelopeCodec.encodeBase64(env)); + } else { + assignField(out, name, value); + } + } + for (Map.Entry v : input.vectors().entrySet()) { + out.vector(v.getKey(), v.getValue()); + } + for (String n : input.nullFields()) { + out.nullField(n); + } + result.add(out); + } + return result; + } + + /** Resolves one key through the schema key provider and converts every failure into KeyResolutionException: an exception thrown by the provider, a null result, or a key that is not AesGcm256.KEY_LEN bytes long. */ private static byte[] resolveKey(EncryptedSchema es, String keyId) { + byte[] k; + try { + k = es.keyProvider().resolve(keyId); + } catch (RuntimeException t) { + throw new KeyResolutionException("KeyProvider.resolve threw for keyId='" + keyId + "'", t); + } + if (k == null) { + throw new KeyResolutionException("KeyProvider returned null for keyId='" + keyId + "'"); + } + if (k.length != AesGcm256.KEY_LEN) { + throw new KeyResolutionException( + "KeyProvider returned " + k.length + "-byte key, expected " + + AesGcm256.KEY_LEN + " bytes for keyId='" + keyId + "'"); + } + return k; + } + + /** Copies one untyped field value onto the output document through the matching typed Doc.field overload; null becomes a null field and any other runtime type raises IllegalStateException. Same logic as the private helper of the same name in DecryptingProjector. */ private static void assignField(Doc out, String name, Object value) { + if (value instanceof String) out.field(name, (String) value); + else if (value instanceof Boolean) out.field(name, ((Boolean) value).booleanValue()); + else if (value instanceof Long) out.field(name, ((Long) value).longValue()); + else if (value instanceof Double) out.field(name, ((Double)
value).doubleValue()); + else if (value == null) out.nullField(name); + else throw new IllegalStateException("unexpected field value type: " + value.getClass()); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionConfigException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionConfigException.java new file mode 100644 index 000000000..1e16ff0f5 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionConfigException.java @@ -0,0 +1,13 @@ +package org.zvec.crypto; + +/** Base for configuration and structural encryption errors. NOTE(review): this class is package-private while its subclass EncryptedCollectionException is public, so code outside org.zvec.crypto cannot name this type in a catch clause and has to catch EncryptionException or the concrete subclass instead; confirm this is intended. */ +abstract class EncryptionConfigException extends EncryptionException { + + protected EncryptionConfigException(String message) { + super(message); + } + + protected EncryptionConfigException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionException.java new file mode 100644 index 000000000..88a79127b --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionException.java @@ -0,0 +1,13 @@ +package org.zvec.crypto; + +/** Root of the zvec encryption exception hierarchy.
*/ +public abstract class EncryptionException extends RuntimeException { + + protected EncryptionException(String message) { + super(message); + } + + protected EncryptionException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionFailedException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionFailedException.java new file mode 100644 index 000000000..0ae9bad30 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionFailedException.java @@ -0,0 +1,9 @@ +package org.zvec.crypto; + +/** Thrown when field encryption fails during a write operation. */ +public final class EncryptionFailedException extends EncryptionRuntimeException { + + public EncryptionFailedException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadata.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadata.java new file mode 100644 index 000000000..f2c6cce05 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadata.java @@ -0,0 +1,82 @@ +package org.zvec.crypto; + +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** Whole-collection encryption metadata, mirroring the sidecar JSON shape. 
*/ +public final class EncryptionMetadata { + public static final int VERSION_V1 = 1; + + private final int version; + private final String collectionName; + private final Map fields; + + public EncryptionMetadata(int version, String collectionName, Map fields) { + this.version = version; + this.collectionName = Objects.requireNonNull(collectionName, "collectionName"); + Objects.requireNonNull(fields, "fields"); + if (version != VERSION_V1) { + throw new IllegalArgumentException("v1 only supports metadata version=" + VERSION_V1 + ", got " + version); + } + this.fields = Map.copyOf(fields); + } + + public int version() { + return version; + } + + public String collectionName() { + return collectionName; + } + + public Map fields() { + return fields; + } + + public static EncryptionMetadata empty(String collectionName) { + return new EncryptionMetadata(VERSION_V1, collectionName, Map.of()); + } + + public boolean isEncrypted(String fieldName) { + return fields.containsKey(fieldName); + } + + public EncryptionSpec spec(String fieldName) { + return fields.get(fieldName); + } + + public Set encryptedFieldNames() { + return fields.keySet(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof EncryptionMetadata)) { + return false; + } + EncryptionMetadata other = (EncryptionMetadata) obj; + return version == other.version + && collectionName.equals(other.collectionName) + && fields.equals(other.fields); + } + + @Override + public int hashCode() { + return Objects.hash(version, collectionName, fields); + } + + @Override + public String toString() { + return "EncryptionMetadata[version=" + + version + + ", collectionName=" + + collectionName + + ", fields=" + + fields + + "]"; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadataIOException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadataIOException.java new file mode 100644 
index 000000000..8baf80536 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadataIOException.java @@ -0,0 +1,9 @@ +package org.zvec.crypto; + +/** Thrown when encryption metadata cannot be read from or written to storage. */ +public final class EncryptionMetadataIOException extends EncryptionConfigException { + + public EncryptionMetadataIOException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadataMismatchException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadataMismatchException.java new file mode 100644 index 000000000..6a389d6d1 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionMetadataMismatchException.java @@ -0,0 +1,9 @@ +package org.zvec.crypto; + +/** Thrown when encryption metadata does not match the expected schema. */ +public final class EncryptionMetadataMismatchException extends EncryptionConfigException { + + public EncryptionMetadataMismatchException(String message) { + super(message); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionRuntimeException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionRuntimeException.java new file mode 100644 index 000000000..a15ec1c7d --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionRuntimeException.java @@ -0,0 +1,13 @@ +package org.zvec.crypto; + +/** Base for runtime encryption/decryption errors. 
*/ +abstract class EncryptionRuntimeException extends EncryptionException { + + protected EncryptionRuntimeException(String message) { + super(message); + } + + protected EncryptionRuntimeException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionSpec.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionSpec.java new file mode 100644 index 000000000..0b2d74591 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EncryptionSpec.java @@ -0,0 +1,76 @@ +package org.zvec.crypto; + +import java.time.Instant; +import java.util.Objects; + +/** Per-field crypto metadata as persisted in the sidecar. */ +public final class EncryptionSpec { + public static final String ALG_AES_256_GCM = "AES-256-GCM"; + + private final String alg; + private final String activeKeyId; + private final Instant createdAt; + private final Instant rotatedAt; + + public EncryptionSpec(String alg, String activeKeyId, Instant createdAt, Instant rotatedAt) { + this.alg = Objects.requireNonNull(alg, "alg"); + this.activeKeyId = Objects.requireNonNull(activeKeyId, "activeKeyId"); + this.createdAt = Objects.requireNonNull(createdAt, "createdAt"); + this.rotatedAt = rotatedAt; + if (!ALG_AES_256_GCM.equals(alg)) { + throw new IllegalArgumentException("v1 only supports alg=" + ALG_AES_256_GCM + ", got " + alg); + } + if (activeKeyId.isEmpty()) { + throw new IllegalArgumentException("activeKeyId must not be empty"); + } + } + + public String alg() { + return alg; + } + + public String activeKeyId() { + return activeKeyId; + } + + public Instant createdAt() { + return createdAt; + } + + public Instant rotatedAt() { + return rotatedAt; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof EncryptionSpec)) { + return false; + } + EncryptionSpec other = (EncryptionSpec) obj; + return 
alg.equals(other.alg) + && activeKeyId.equals(other.activeKeyId) + && createdAt.equals(other.createdAt) + && Objects.equals(rotatedAt, other.rotatedAt); + } + + @Override + public int hashCode() { + return Objects.hash(alg, activeKeyId, createdAt, rotatedAt); + } + + @Override + public String toString() { + return "EncryptionSpec[alg=" + + alg + + ", activeKeyId=" + + activeKeyId + + ", createdAt=" + + createdAt + + ", rotatedAt=" + + rotatedAt + + "]"; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/Envelope.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/Envelope.java new file mode 100644 index 000000000..de39de754 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/Envelope.java @@ -0,0 +1,96 @@ +package org.zvec.crypto; + +import java.util.Objects; + +/** Parsed envelope. Field positions correspond to the binary layout in EnvelopeCodec. */ +public final class Envelope { + public static final int VERSION_V1 = 0x01; + public static final int ALG_AES_256_GCM = 0x01; + public static final int PAYLOAD_STRING = 0x00; + public static final int NONCE_LEN = 12; + + private final int version; + private final int alg; + private final int payloadType; + private final String keyId; + private final byte[] nonce; + private final byte[] ciphertext; + + public Envelope( + int version, + int alg, + int payloadType, + String keyId, + byte[] nonce, + byte[] ciphertext) { + this.version = version; + this.alg = alg; + this.payloadType = payloadType; + this.keyId = Objects.requireNonNull(keyId, "keyId"); + this.nonce = Objects.requireNonNull(nonce, "nonce"); + this.ciphertext = Objects.requireNonNull(ciphertext, "ciphertext"); + } + + public int version() { + return version; + } + + public int alg() { + return alg; + } + + public int payloadType() { + return payloadType; + } + + public String keyId() { + return keyId; + } + + public byte[] nonce() { + return nonce; + } + + public byte[] ciphertext() { + 
return ciphertext; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof Envelope)) { + return false; + } + Envelope other = (Envelope) obj; + return version == other.version + && alg == other.alg + && payloadType == other.payloadType + && keyId.equals(other.keyId) + && Objects.equals(nonce, other.nonce) + && Objects.equals(ciphertext, other.ciphertext); + } + + @Override + public int hashCode() { + return Objects.hash(version, alg, payloadType, keyId, nonce, ciphertext); + } + + @Override + public String toString() { + return "Envelope[version=" + + version + + ", alg=" + + alg + + ", payloadType=" + + payloadType + + ", keyId=" + + keyId + + ", nonce=" + + nonce + + ", ciphertext=" + + ciphertext + + "]"; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EnvelopeCodec.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EnvelopeCodec.java new file mode 100644 index 000000000..f1c196849 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EnvelopeCodec.java @@ -0,0 +1,72 @@ +package org.zvec.crypto; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +public final class EnvelopeCodec { + private static final Base64.Encoder B64_ENC = Base64.getUrlEncoder().withoutPadding(); + private static final Base64.Decoder B64_DEC = Base64.getUrlDecoder(); + + private EnvelopeCodec() {} + + public static byte[] encode(Envelope env) { + byte[] kid = env.keyId().getBytes(StandardCharsets.UTF_8); + if (kid.length < 1 || kid.length > 255) { + throw new IllegalArgumentException("keyId UTF-8 length must be 1..255, got " + kid.length); + } + if (env.nonce().length != Envelope.NONCE_LEN) { + throw new IllegalArgumentException("nonce must be " + Envelope.NONCE_LEN + " bytes"); + } + int total = 4 + kid.length + Envelope.NONCE_LEN + env.ciphertext().length; + ByteBuffer buf = ByteBuffer.allocate(total); + 
buf.put((byte) env.version()); + buf.put((byte) env.alg()); + buf.put((byte) env.payloadType()); + buf.put((byte) kid.length); + buf.put(kid); + buf.put(env.nonce()); + buf.put(env.ciphertext()); + return buf.array(); + } + + public static Envelope decode(byte[] data) { + if (data == null || data.length < 4) { + throw new EnvelopeFormatException("envelope truncated: length=" + (data == null ? 0 : data.length)); + } + int version = data[0] & 0xff; + int alg = data[1] & 0xff; + int payload = data[2] & 0xff; + int kidLen = data[3] & 0xff; + + if (version != Envelope.VERSION_V1) { + throw new EnvelopeFormatException("unsupported envelope version=0x" + Integer.toHexString(version)); + } + if (kidLen == 0) { + throw new EnvelopeFormatException("envelope keyId_len cannot be zero"); + } + int kidEnd = 4 + kidLen; + int nonceEnd = kidEnd + Envelope.NONCE_LEN; + if (data.length < nonceEnd) { + throw new EnvelopeFormatException("envelope truncated: expected at least " + nonceEnd + " bytes, got " + data.length); + } + String keyId = new String(data, 4, kidLen, StandardCharsets.UTF_8); + byte[] nonce = java.util.Arrays.copyOfRange(data, kidEnd, nonceEnd); + byte[] ct = java.util.Arrays.copyOfRange(data, nonceEnd, data.length); + return new Envelope(version, alg, payload, keyId, nonce, ct); + } + + public static String encodeBase64(Envelope env) { + return B64_ENC.encodeToString(encode(env)); + } + + public static Envelope decodeBase64(String s) { + byte[] raw; + try { + raw = B64_DEC.decode(s); + } catch (IllegalArgumentException e) { + throw new EnvelopeFormatException("envelope base64 decode failed: " + e.getMessage()); + } + return decode(raw); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EnvelopeFormatException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EnvelopeFormatException.java new file mode 100644 index 000000000..f850c55dc --- /dev/null +++ 
b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/EnvelopeFormatException.java @@ -0,0 +1,9 @@ +package org.zvec.crypto; + +/** Thrown when an encrypted envelope cannot be parsed due to format errors. */ +public final class EnvelopeFormatException extends DecryptionException { + + public EnvelopeFormatException(String message) { + super(message); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/FilterFieldScanner.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/FilterFieldScanner.java new file mode 100644 index 000000000..77e4ecb27 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/FilterFieldScanner.java @@ -0,0 +1,56 @@ +package org.zvec.crypto; + +import java.util.LinkedHashSet; +import java.util.Set; + +/** + * Extracts identifier-like tokens from a filter string, skipping string + * literals (single- or double-quoted, with backslash escapes). Returns + * every identifier token it sees, including SQL keywords like AND/OR — + * the caller intersects with the encrypted-field set, so keywords don't + * cause false positives. 
/**
 * Collects identifier-like tokens from a filter expression while ignoring the
 * contents of string literals (single- or double-quoted, backslash-escaped).
 *
 * <p>Every identifier token is reported, keywords such as AND/OR included; the
 * caller is expected to intersect the result with its own set of field names.
 */
public final class FilterFieldScanner {
    private FilterFieldScanner() {}

    /**
     * Scans {@code filter} left to right and returns the identifiers found, in
     * first-seen order.
     *
     * @param filter filter expression; null or empty yields an empty set
     * @return a mutable, insertion-ordered set of identifier tokens
     */
    public static Set<String> referencedFields(String filter) {
        Set<String> found = new LinkedHashSet<>();
        if (filter == null || filter.isEmpty()) {
            return found;
        }

        final int length = filter.length();
        int pos = 0;
        while (pos < length) {
            char ch = filter.charAt(pos);
            if (ch == '\'' || ch == '"') {
                pos = endOfQuoted(filter, pos, ch);
                continue;
            }
            if (!startsIdentifier(ch)) {
                pos++;
                continue;
            }
            // An identifier start is also a valid identifier part, so begin the scan one past it.
            int end = pos + 1;
            while (end < length && continuesIdentifier(filter.charAt(end))) {
                end++;
            }
            found.add(filter.substring(pos, end));
            pos = end;
        }
        return found;
    }

    /**
     * Returns the index just past the closing quote of the literal opened at
     * {@code openIdx}, or the text length when the literal is never closed.
     */
    private static int endOfQuoted(String text, int openIdx, char quote) {
        final int length = text.length();
        int pos = openIdx + 1;
        while (pos < length) {
            char ch = text.charAt(pos);
            if (ch == '\\' && pos + 1 < length) {
                pos += 2; // skip the escaped character as well
            } else if (ch == quote) {
                return pos + 1;
            } else {
                pos++;
            }
        }
        return length;
    }

    private static boolean startsIdentifier(char ch) {
        return ch == '_' || Character.isLetter(ch);
    }

    private static boolean continuesIdentifier(char ch) {
        return ch == '_' || Character.isLetterOrDigit(ch);
    }
}
+ */ + byte[] resolve(String keyId); + + /** + * Optional liveness check; the library does not call this. Provided for + * caller-side rotation logic. + */ + default boolean isActive(String keyId) { return true; } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/KeyResolutionException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/KeyResolutionException.java new file mode 100644 index 000000000..3959268a1 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/KeyResolutionException.java @@ -0,0 +1,13 @@ +package org.zvec.crypto; + +/** Thrown when a key cannot be resolved from the key provider. */ +public final class KeyResolutionException extends EncryptionRuntimeException { + + public KeyResolutionException(String message) { + super(message); + } + + public KeyResolutionException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SidecarJson.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SidecarJson.java new file mode 100644 index 000000000..243ae617d --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SidecarJson.java @@ -0,0 +1,212 @@ +package org.zvec.crypto; + +import java.time.Instant; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.TreeMap; + +/** + * Minimal hand-rolled JSON codec for the _zvec_enc.json shape only. + * Not a general-purpose JSON parser. Strings, integers, and nested objects + * are the only types we emit or accept. 
+ */ +final class SidecarJson { + private SidecarJson() {} + + static String write(EncryptionMetadata meta) { + StringBuilder sb = new StringBuilder(256); + sb.append("{\n"); + sb.append(" \"version\": ").append(meta.version()).append(",\n"); + sb.append(" \"collection_name\": ").append(quote(meta.collectionName())).append(",\n"); + sb.append(" \"fields\": {"); + Map sorted = new TreeMap<>(meta.fields()); + boolean first = true; + for (Map.Entry e : sorted.entrySet()) { + if (!first) sb.append(","); + first = false; + sb.append("\n ").append(quote(e.getKey())).append(": {"); + EncryptionSpec s = e.getValue(); + sb.append("\n \"alg\": ").append(quote(s.alg())).append(","); + sb.append("\n \"active_key_id\": ").append(quote(s.activeKeyId())).append(","); + sb.append("\n \"created_at\": ").append(quote(s.createdAt().toString())); + if (s.rotatedAt() != null) { + sb.append(",\n \"rotated_at\": ").append(quote(s.rotatedAt().toString())); + } + sb.append("\n }"); + } + if (!sorted.isEmpty()) sb.append("\n "); + sb.append("}\n}\n"); + return sb.toString(); + } + + static EncryptionMetadata read(String json) { + Parser p = new Parser(json); + try { + Map root = p.parseObject(); + p.requireEof(); + int version = ((Number) require(root, "version")).intValue(); + String coll = (String) require(root, "collection_name"); + @SuppressWarnings("unchecked") + Map fieldsObj = (Map) require(root, "fields"); + Map fields = new LinkedHashMap<>(); + for (Map.Entry e : fieldsObj.entrySet()) { + @SuppressWarnings("unchecked") + Map spec = (Map) e.getValue(); + String alg = (String) require(spec, "alg"); + String activeKeyId = (String) require(spec, "active_key_id"); + Instant created = Instant.parse((String) require(spec, "created_at")); + Instant rotated = spec.containsKey("rotated_at") ? 
Instant.parse((String) spec.get("rotated_at")) : null; + fields.put(e.getKey(), new EncryptionSpec(alg, activeKeyId, created, rotated)); + } + return new EncryptionMetadata(version, coll, fields); + } catch (IllegalArgumentException e) { + throw e; // version / spec validation + } catch (EncryptionMetadataIOException e) { + throw e; // already correctly typed and messaged + } catch (Exception e) { + throw new EncryptionMetadataIOException("sidecar JSON parse failed: " + e.getMessage(), e); + } + } + + private static Object require(Map m, String k) { + Object v = m.get(k); + if (v == null) throw new EncryptionMetadataIOException("sidecar missing key: " + k, null); + return v; + } + + private static String quote(String s) { + StringBuilder out = new StringBuilder(s.length() + 2); + out.append('"'); + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + switch (c) { + case '"': + out.append("\\\""); + break; + case '\\': + out.append("\\\\"); + break; + case '\n': + out.append("\\n"); + break; + case '\r': + out.append("\\r"); + break; + case '\t': + out.append("\\t"); + break; + default: + if (c < 0x20) out.append(String.format("\\u%04x", (int) c)); + else out.append(c); + break; + } + } + out.append('"'); + return out.toString(); + } + + /** Single-pass recursive descent for the strict shape we emit. 
*/ + private static final class Parser { + private final String src; + private int i; + + Parser(String src) { this.src = src; } + + Map parseObject() { + skipWs(); + expect('{'); + Map m = new LinkedHashMap<>(); + skipWs(); + if (peek() == '}') { i++; return m; } + while (true) { + skipWs(); + String key = parseString(); + skipWs(); + expect(':'); + skipWs(); + m.put(key, parseValue()); + skipWs(); + if (peek() == ',') { i++; continue; } + expect('}'); + return m; + } + } + + Object parseValue() { + char c = peek(); + if (c == '"') return parseString(); + if (c == '{') return parseObject(); + if (c == '-' || (c >= '0' && c <= '9')) return parseNumber(); + throw new IllegalStateException("unexpected character at " + i + ": " + c); + } + + String parseString() { + expect('"'); + StringBuilder out = new StringBuilder(); + while (i < src.length()) { + char c = src.charAt(i++); + if (c == '"') return out.toString(); + if (c == '\\') { + char esc = src.charAt(i++); + switch (esc) { + case '"': + out.append('"'); + break; + case '\\': + out.append('\\'); + break; + case 'n': + out.append('\n'); + break; + case 'r': + out.append('\r'); + break; + case 't': + out.append('\t'); + break; + case 'u': + out.append((char) Integer.parseInt(src.substring(i, i + 4), 16)); + i += 4; + break; + default: + throw new IllegalStateException("bad escape \\" + esc); + } + } else { + out.append(c); + } + } + throw new IllegalStateException("unterminated string"); + } + + Number parseNumber() { + int start = i; + if (peek() == '-') i++; + while (i < src.length() && Character.isDigit(src.charAt(i))) i++; + return Long.parseLong(src.substring(start, i)); + } + + void expect(char c) { + if (i >= src.length() || src.charAt(i) != c) { + throw new IllegalStateException("expected '" + c + "' at " + i); + } + i++; + } + + void skipWs() { + while (i < src.length() && Character.isWhitespace(src.charAt(i))) i++; + } + + void requireEof() { + skipWs(); + if (i < src.length()) { + throw new 
IllegalStateException( + "unexpected trailing content at offset " + i + ": '" + src.charAt(i) + "'"); + } + } + + char peek() { + if (i >= src.length()) throw new IllegalStateException("unexpected end of input"); + return src.charAt(i); + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SidecarMetadata.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SidecarMetadata.java new file mode 100644 index 000000000..1b68c344b --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SidecarMetadata.java @@ -0,0 +1,48 @@ +package org.zvec.crypto; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.Optional; + +/** Reads/writes _zvec_enc.json inside a collection directory. Atomic on write. */ +public final class SidecarMetadata { + + public static final String FILENAME = "_zvec_enc.json"; + + private SidecarMetadata() {} + + public static Optional read(Path collectionDir) { + Path file = collectionDir.resolve(FILENAME); + if (!Files.exists(file)) { + return Optional.empty(); + } + String text; + try { + text = Files.readString(file, StandardCharsets.UTF_8); + } catch (IOException e) { + throw new EncryptionMetadataIOException("read sidecar failed: " + file, e); + } + return Optional.of(SidecarJson.read(text)); + } + + public static void write(Path collectionDir, EncryptionMetadata meta) { + String text = SidecarJson.write(meta); + Path file = collectionDir.resolve(FILENAME); + Path tmp = collectionDir.resolve(FILENAME + ".tmp." 
+ java.util.UUID.randomUUID()); + try { + Files.createDirectories(collectionDir); + Files.writeString(tmp, text, StandardCharsets.UTF_8); + try { + Files.move(tmp, file, StandardCopyOption.ATOMIC_MOVE, StandardCopyOption.REPLACE_EXISTING); + } catch (java.nio.file.AtomicMoveNotSupportedException e) { + Files.move(tmp, file, StandardCopyOption.REPLACE_EXISTING); + } + } catch (IOException e) { + try { Files.deleteIfExists(tmp); } catch (IOException ignored) {} + throw new EncryptionMetadataIOException("write sidecar failed: " + file, e); + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SingletonKeyProvider.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SingletonKeyProvider.java new file mode 100644 index 000000000..4a6cc6e9e --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/SingletonKeyProvider.java @@ -0,0 +1,29 @@ +package org.zvec.crypto; + +import java.util.Objects; + +/** Internal: wraps a single (keyId, key) pair as a KeyProvider. Used by the static-key sugar. */ +public final class SingletonKeyProvider implements KeyProvider { + private final String keyId; + private final byte[] key; + + public SingletonKeyProvider(String keyId, byte[] key) { + Objects.requireNonNull(keyId, "keyId"); + Objects.requireNonNull(key, "key"); + if (keyId.isEmpty()) { + throw new IllegalArgumentException("keyId must not be empty"); + } + if (key.length != AesGcm256.KEY_LEN) { + throw new IllegalArgumentException("key must be " + AesGcm256.KEY_LEN + " bytes, got " + key.length); + } + this.keyId = keyId; + this.key = key.clone(); + } + + String keyId() { return keyId; } + + @Override + public byte[] resolve(String requested) { + return keyId.equals(requested) ? 
key.clone() : null; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/UnsupportedFieldTypeException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/UnsupportedFieldTypeException.java new file mode 100644 index 000000000..99783f267 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/crypto/UnsupportedFieldTypeException.java @@ -0,0 +1,9 @@ +package org.zvec.crypto; + +/** Thrown when encryption is requested for a field type that is not supported. */ +public final class UnsupportedFieldTypeException extends EncryptionConfigException { + + public UnsupportedFieldTypeException(String message) { + super(message); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/HnswDefaults.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/HnswDefaults.java new file mode 100644 index 000000000..8fc85e7fb --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/HnswDefaults.java @@ -0,0 +1,135 @@ +package org.zvec.internal; + +import java.util.Objects; +import org.zvec.HnswIndexParams; +import org.zvec.HnswQueryParams; +import org.zvec.TuningProfile; +import org.zvec.VectorQuery; +import org.zvec.VectorSchema; + +public final class HnswDefaults { + private HnswDefaults() {} + + public static HnswIndexParams resolveIndexParams(VectorSchema schema) { + Objects.requireNonNull(schema, "schema"); + HnswIndexParams rawParams = schema.hnswIndexParams(); + if (rawParams != null) { + return rawParams; + } + + TuningProfile profile = + schema.tuningProfile() == null ? 
TuningProfile.BALANCED : schema.tuningProfile(); + Long expectedDocCount = schema.expectedDocCount(); + if (expectedDocCount == null) { + return indexParamsFor(SizeBucket.SMALL, profile); + } + + return indexParamsFor(bucket(expectedDocCount), profile); + } + + public static HnswQueryParams resolveQueryParams(VectorSchema schema, VectorQuery query) { + Objects.requireNonNull(schema, "schema"); + Objects.requireNonNull(query, "query"); + + HnswQueryParams rawParams = query.hnswQueryParams(); + if (rawParams != null) { + return rawParams; + } + + TuningProfile profile = + query.tuningProfile() != null + ? query.tuningProfile() + : schema.tuningProfile() != null ? schema.tuningProfile() : TuningProfile.BALANCED; + Long expectedDocCount = schema.expectedDocCount(); + int ef = queryEfFor(expectedDocCount == null ? SizeBucket.SMALL : bucket(expectedDocCount), profile); + return new HnswQueryParams(ef, 0.0f, false, false); + } + + private static HnswIndexParams indexParamsFor(SizeBucket bucket, TuningProfile profile) { + switch (bucket) { + case SMALL: + switch (profile) { + case FAST: + return new HnswIndexParams(12, 120); + case BALANCED: + return new HnswIndexParams(16, 200); + case ACCURATE: + return new HnswIndexParams(24, 300); + } + break; + case MEDIUM: + switch (profile) { + case FAST: + return new HnswIndexParams(16, 200); + case BALANCED: + return new HnswIndexParams(24, 300); + case ACCURATE: + return new HnswIndexParams(32, 400); + } + break; + case LARGE: + switch (profile) { + case FAST: + return new HnswIndexParams(16, 240); + case BALANCED: + return new HnswIndexParams(32, 400); + case ACCURATE: + return new HnswIndexParams(40, 500); + } + break; + } + throw new IllegalStateException("Unhandled size bucket/profile: " + bucket + "/" + profile); + } + + private static int queryEfFor(SizeBucket bucket, TuningProfile profile) { + switch (bucket) { + case SMALL: + switch (profile) { + case FAST: + return 32; + case BALANCED: + return 64; + case ACCURATE: + 
return 96; + } + break; + case MEDIUM: + switch (profile) { + case FAST: + return 48; + case BALANCED: + return 96; + case ACCURATE: + return 128; + } + break; + case LARGE: + switch (profile) { + case FAST: + return 64; + case BALANCED: + return 128; + case ACCURATE: + return 192; + } + break; + } + throw new IllegalStateException("Unhandled size bucket/profile: " + bucket + "/" + profile); + } + + private static SizeBucket bucket(long expectedDocCount) { + if (expectedDocCount <= 100_000L) { + return SizeBucket.SMALL; + } + if (expectedDocCount <= 1_000_000L) { + return SizeBucket.MEDIUM; + } + return SizeBucket.LARGE; + } + + private enum SizeBucket { + SMALL, + MEDIUM, + LARGE + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackend.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackend.java new file mode 100644 index 000000000..572424f7a --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackend.java @@ -0,0 +1,32 @@ +package org.zvec.internal; + +import java.util.List; +import org.zvec.CollectionSchema; +import org.zvec.Doc; +import org.zvec.VectorQuery; + +public interface NativeBackend { + String id(); + + String version(); + + void ensureInitialized(); + + NativeOpenResult open(String path); + + NativeOpenResult createAndOpen(String path, CollectionSchema schema); + + void close(NativeHandle handle); + + void flush(NativeHandle handle); + + CollectionSchema readSchema(NativeHandle handle); + + int insert(NativeHandle handle, CollectionSchema schema, List docs); + + List query( + NativeHandle handle, + CollectionSchema querySchema, + CollectionSchema resultSchema, + VectorQuery query); +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackendProvider.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackendProvider.java new file mode 100644 index 000000000..38b10d83e --- /dev/null +++ 
b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackendProvider.java @@ -0,0 +1,6 @@ +package org.zvec.internal; + +@FunctionalInterface +public interface NativeBackendProvider { + NativeBackend create(); +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackends.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackends.java new file mode 100644 index 000000000..b15da786e --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeBackends.java @@ -0,0 +1,72 @@ +package org.zvec.internal; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.ServiceLoader; +import java.util.stream.Collectors; + +public final class NativeBackends { + public static final String BACKEND_PROPERTY = "org.zvec.backend"; + + private NativeBackends() {} + + public static NativeBackend backend() { + return Holder.BACKEND; + } + + static NativeBackend resolve(Iterable providers, String requestedId) { + Objects.requireNonNull(providers, "providers"); + + List backends = new ArrayList<>(); + for (NativeBackendProvider provider : providers) { + NativeBackend backend = Objects.requireNonNull(provider.create(), "backend"); + backends.add(backend); + } + + if (backends.isEmpty()) { + throw new IllegalStateException( + "No zvec native backend found. Add exactly one backend dependency: zvec-java-jni or zvec-java-ffm."); + } + + String requested = normalize(requestedId); + if (requested != null) { + for (NativeBackend backend : backends) { + if (backend.id().equals(requested)) { + return backend; + } + } + throw new IllegalStateException( + "Requested zvec native backend '" + + requested + + "' was not found. Available backends: " + + backendIds(backends) + + "."); + } + + if (backends.size() == 1) { + return backends.get(0); + } + + throw new IllegalStateException( + "Multiple zvec native backends found: " + + backendIds(backends) + + ". 
Set -Dorg.zvec.backend=jni or -Dorg.zvec.backend=ffm."); + } + + private static String normalize(String requestedId) { + if (requestedId == null || requestedId.isBlank()) { + return null; + } + return requestedId.trim(); + } + + private static String backendIds(List backends) { + return backends.stream().map(NativeBackend::id).collect(Collectors.joining(", ")); + } + + private static final class Holder { + private static final NativeBackend BACKEND = + resolve(ServiceLoader.load(NativeBackendProvider.class), System.getProperty(BACKEND_PROPERTY)); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeHandle.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeHandle.java new file mode 100644 index 000000000..c008bc174 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeHandle.java @@ -0,0 +1,3 @@ +package org.zvec.internal; + +public interface NativeHandle {} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeOpenResult.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeOpenResult.java new file mode 100644 index 000000000..916ed7a31 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/NativeOpenResult.java @@ -0,0 +1,44 @@ +package org.zvec.internal; + +import java.util.Objects; +import org.zvec.CollectionSchema; + +public final class NativeOpenResult { + private final NativeHandle handle; + private final CollectionSchema querySchema; + + public NativeOpenResult(NativeHandle handle, CollectionSchema querySchema) { + this.handle = Objects.requireNonNull(handle, "handle"); + this.querySchema = Objects.requireNonNull(querySchema, "querySchema"); + } + + public NativeHandle handle() { + return handle; + } + + public CollectionSchema querySchema() { + return querySchema; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof 
NativeOpenResult)) { + return false; + } + NativeOpenResult other = (NativeOpenResult) obj; + return handle.equals(other.handle) && querySchema.equals(other.querySchema); + } + + @Override + public int hashCode() { + return Objects.hash(handle, querySchema); + } + + @Override + public String toString() { + return "NativeOpenResult[handle=" + handle + ", querySchema=" + querySchema + "]"; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/SchemaMetadataStore.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/SchemaMetadataStore.java new file mode 100644 index 000000000..3cfaa9b16 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/SchemaMetadataStore.java @@ -0,0 +1,132 @@ +package org.zvec.internal; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; +import org.zvec.CollectionSchema; +import org.zvec.TuningProfile; +import org.zvec.VectorSchema; + +public final class SchemaMetadataStore { + private static final String VERSION_KEY = "version"; + private static final String VERSION = "1"; + private static final String FILE_NAME = ".zvec-java-schema.properties"; + + private SchemaMetadataStore() {} + + public static void write(String collectionPath, CollectionSchema schema) { + Path metadataPath = metadataPath(collectionPath); + Properties properties = new Properties(); + properties.setProperty(VERSION_KEY, VERSION); + + for (VectorSchema vector : schema.vectors()) { + String prefix = vectorPrefix(vector.name()); + properties.setProperty( + prefix + "rawIndexParamsExplicit", Boolean.toString(vector.hnswIndexParams() != null)); + if (vector.tuningProfile() != null) { + properties.setProperty(prefix + "tuningProfile", vector.tuningProfile().name()); + } + if (vector.expectedDocCount() != null) { + 
properties.setProperty(prefix + "expectedDocCount", Long.toString(vector.expectedDocCount())); + } + } + + try { + Files.createDirectories(metadataPath.getParent()); + try (OutputStream output = Files.newOutputStream(metadataPath)) { + properties.store(output, "zvec-java schema metadata"); + } + } catch (IOException e) { + throw new IllegalStateException("Failed to write Java schema metadata", e); + } + } + + public static CollectionSchema merge(String collectionPath, CollectionSchema nativeSchema) { + Path metadataPath = metadataPath(collectionPath); + if (!Files.isRegularFile(metadataPath)) { + return nativeSchema; + } + + Properties properties = new Properties(); + try (InputStream input = Files.newInputStream(metadataPath)) { + properties.load(input); + if (!VERSION.equals(properties.getProperty(VERSION_KEY))) { + return nativeSchema; + } + + List vectors = new ArrayList<>(nativeSchema.vectors().size()); + for (VectorSchema vector : nativeSchema.vectors()) { + vectors.add(mergeVector(properties, vector)); + } + return new CollectionSchema(nativeSchema.name(), nativeSchema.fields(), vectors); + } catch (IOException | RuntimeException ignored) { + return nativeSchema; + } + } + + private static VectorSchema mergeVector(Properties properties, VectorSchema vector) { + String prefix = vectorPrefix(vector.name()); + String rawExplicitValue = properties.getProperty(prefix + "rawIndexParamsExplicit"); + if (rawExplicitValue == null) { + throw new IllegalStateException( + "Missing raw state flag in Java schema metadata: " + prefix + "rawIndexParamsExplicit"); + } + boolean rawIndexParamsExplicit = + parseBoolean(rawExplicitValue, prefix + "rawIndexParamsExplicit"); + + TuningProfile tuningProfile = vector.tuningProfile(); + String profileValue = properties.getProperty(prefix + "tuningProfile"); + if (profileValue != null) { + try { + tuningProfile = TuningProfile.valueOf(profileValue); + } catch (IllegalArgumentException e) { + throw new IllegalStateException("Invalid 
tuning profile in Java schema metadata: " + profileValue, e); + } + } + + Long expectedDocCount = vector.expectedDocCount(); + String expectedDocCountValue = properties.getProperty(prefix + "expectedDocCount"); + if (expectedDocCountValue != null) { + expectedDocCount = parseLong(expectedDocCountValue, prefix + "expectedDocCount"); + } + + return new VectorSchema( + vector.name(), + vector.dataType(), + vector.dimension(), + rawIndexParamsExplicit ? vector.hnswIndexParams() : null, + tuningProfile, + expectedDocCount); + } + + private static long parseLong(String value, String key) { + try { + return Long.parseLong(value); + } catch (NumberFormatException e) { + throw new IllegalStateException("Invalid long in Java schema metadata: " + key, e); + } + } + + private static boolean parseBoolean(String value, String key) { + if ("true".equalsIgnoreCase(value)) { + return true; + } + if ("false".equalsIgnoreCase(value)) { + return false; + } + throw new IllegalStateException("Invalid boolean in Java schema metadata: " + key); + } + + private static String vectorPrefix(String vectorName) { + return "vector." 
+ vectorName + "."; + } + + private static Path metadataPath(String collectionPath) { + return Path.of(collectionPath).resolve(FILE_NAME); + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/ZvecException.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/ZvecException.java new file mode 100644 index 000000000..876ebbd2e --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/internal/ZvecException.java @@ -0,0 +1,14 @@ +package org.zvec.internal; + +public final class ZvecException extends RuntimeException { + private final int code; + + public ZvecException(int code, String message) { + super(message); + this.code = code; + } + + public int code() { + return code; + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/CollectionConcurrentStressMain.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/CollectionConcurrentStressMain.java new file mode 100644 index 000000000..d57761eec --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/CollectionConcurrentStressMain.java @@ -0,0 +1,492 @@ +package org.zvec.perf; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import org.zvec.Collection; +import org.zvec.CollectionSchema; +import org.zvec.Doc; +import org.zvec.HnswQueryParams; +import org.zvec.VectorQuery; +import org.zvec.Zvec; + +public final class CollectionConcurrentStressMain { + private CollectionConcurrentStressMain() {} + + public static void main(String[] args) throws Exception { + StressOptions options = StressOptions.parse(args); + Path runDir = 
prepareRunDirectory(options.workDir(), options.docCount()); + + System.out.println("CONCURRENT_STRESS_CONFIG " + formatConfig(options, runDir)); + + CollectionSchema schema = + PerfData.schema("perf_docs_concurrent", options.dimension(), options.hnswIndexParams()); + try (Collection collection = Zvec.createAndOpen(runDir.toString(), schema)) { + insertInitialDataset(collection, options); + collection.flush(); + warmupQueries(collection, options, options.warmupQueries()); + + QueryOnlyMetrics queryOnly = runQueryOnly(collection, options); + MixedMetrics mixed = runMixed(collection, options); + + printSummary(options, queryOnly, mixed, runDir); + } + } + + private static void insertInitialDataset(Collection collection, StressOptions options) { + for (int batchStart = 0; batchStart < options.docCount(); batchStart += options.batchSize()) { + int batchSize = Math.min(options.batchSize(), options.docCount() - batchStart); + List docs = PerfData.docs(batchStart, batchSize, options.dimension(), options.seed()); + int inserted = collection.insert(docs); + if (inserted != batchSize) { + throw new IllegalStateException( + "Inserted count mismatch: expected " + batchSize + ", got " + inserted); + } + } + } + + private static void warmupQueries(Collection collection, StressOptions options, int queryCount) { + Random random = new Random(options.seed() ^ 0x13579BDFL); + for (int i = 0; i < queryCount; i++) { + int docIndex = random.nextInt(options.docCount()); + PerfData.VectorSample sample = + PerfData.querySample(docIndex, options.dimension(), options.seed(), options.topK()); + collection.query(buildQuery(sample, options)); + } + } + + private static QueryOnlyMetrics runQueryOnly(Collection collection, StressOptions options) + throws Exception { + int threads = options.concurrentQueryThreads(); + int queriesPerThread = options.concurrentQueryCount(); + int totalQueries = threads * queriesPerThread; + long[] latencies = new long[totalQueries]; + AtomicInteger failureCount = new 
AtomicInteger(); + AtomicInteger missCount = new AtomicInteger(); + + long startedAt = System.nanoTime(); + runConcurrent( + threads, + threadIndex -> { + Random random = new Random(options.seed() ^ 0x2468ACE0L ^ threadIndex); + int offset = threadIndex * queriesPerThread; + for (int i = 0; i < queriesPerThread; i++) { + int docIndex = random.nextInt(options.docCount()); + PerfData.VectorSample sample = + PerfData.querySample(docIndex, options.dimension(), options.seed(), options.topK()); + long queryStartedAt = System.nanoTime(); + try { + List results = collection.query(buildQuery(sample, options)); + latencies[offset + i] = System.nanoTime() - queryStartedAt; + if (!CollectionStressMain.hasExpectedHit(sample.expectedId(), results)) { + missCount.incrementAndGet(); + } + } catch (RuntimeException e) { + failureCount.incrementAndGet(); + } + } + }); + long elapsedNanos = System.nanoTime() - startedAt; + + return new QueryOnlyMetrics( + totalQueries, + failureCount.get(), + missCount.get(), + latencyStats(latencies), + elapsedNanos); + } + + private static MixedMetrics runMixed(Collection collection, StressOptions options) throws Exception { + int threads = options.concurrentMixedThreads(); + int rounds = options.concurrentMixedRounds(); + int insertBatchSize = options.concurrentMixedInsertBatchSize(); + int queriesPerRound = options.concurrentMixedQueriesPerRound(); + + long[] insertLatencies = new long[threads * rounds]; + long[] queryLatencies = new long[threads * rounds * queriesPerRound]; + AtomicInteger insertFailures = new AtomicInteger(); + AtomicInteger queryFailures = new AtomicInteger(); + AtomicInteger missCount = new AtomicInteger(); + AtomicInteger nextDocIndex = new AtomicInteger(options.docCount()); + + long startedAt = System.nanoTime(); + runConcurrent( + threads, + threadIndex -> { + Random random = new Random(options.seed() ^ 0x55AA55AAL ^ threadIndex); + int insertOffset = threadIndex * rounds; + int queryOffset = threadIndex * rounds * 
queriesPerRound; + + for (int round = 0; round < rounds; round++) { + int startDocIndex = nextDocIndex.getAndAdd(insertBatchSize); + long insertStartedAt = System.nanoTime(); + try { + List docs = + PerfData.docs(startDocIndex, insertBatchSize, options.dimension(), options.seed()); + int inserted = collection.insert(docs); + if (inserted != insertBatchSize) { + insertFailures.incrementAndGet(); + } + insertLatencies[insertOffset + round] = System.nanoTime() - insertStartedAt; + } catch (RuntimeException e) { + insertFailures.incrementAndGet(); + } + + for (int q = 0; q < queriesPerRound; q++) { + int docIndex = random.nextInt(options.docCount()); + PerfData.VectorSample sample = + PerfData.querySample(docIndex, options.dimension(), options.seed(), options.topK()); + long queryStartedAt = System.nanoTime(); + try { + List results = collection.query(buildQuery(sample, options)); + queryLatencies[queryOffset + round * queriesPerRound + q] = + System.nanoTime() - queryStartedAt; + if (!CollectionStressMain.hasExpectedHit(sample.expectedId(), results)) { + missCount.incrementAndGet(); + } + } catch (RuntimeException e) { + queryFailures.incrementAndGet(); + } + } + } + }); + long elapsedNanos = System.nanoTime() - startedAt; + + return new MixedMetrics( + threads * rounds * insertBatchSize, + insertFailures.get(), + latencyStats(insertLatencies), + threads * rounds * queriesPerRound, + queryFailures.get(), + missCount.get(), + latencyStats(queryLatencies), + elapsedNanos); + } + + private static void runConcurrent(int threads, ThrowingIntConsumer task) throws Exception { + ExecutorService executor = Executors.newFixedThreadPool(threads); + CountDownLatch start = new CountDownLatch(1); + List> futures = new ArrayList<>(threads); + try { + for (int threadIndex = 0; threadIndex < threads; threadIndex++) { + int index = threadIndex; + futures.add( + executor.submit( + () -> { + start.await(); + task.accept(index); + return null; + })); + } + start.countDown(); + for (Future 
future : futures) { + future.get(); + } + } finally { + executor.shutdownNow(); + } + } + + private static LatencyStats latencyStats(long[] samples) { + int successCount = 0; + for (long sample : samples) { + if (sample > 0L) { + successCount++; + } + } + if (successCount == 0) { + return null; + } + if (successCount == samples.length) { + return LatencyStats.fromNanos(samples); + } + long[] compacted = new long[successCount]; + int index = 0; + for (long sample : samples) { + if (sample > 0L) { + compacted[index++] = sample; + } + } + return LatencyStats.fromNanos(compacted); + } + + private static void printSummary( + StressOptions options, QueryOnlyMetrics queryOnly, MixedMetrics mixed, Path runDir) { + System.out.printf( + Locale.ROOT, + "QUERY_ONLY_SUMMARY threads=%d queries=%d query_failures=%d miss_count=%d recall=%.4f seconds=%.3f queries_per_sec=%.1f p50_us=%s p95_us=%s p99_us=%s%n", + options.concurrentQueryThreads(), + queryOnly.queryCount(), + queryOnly.failureCount(), + queryOnly.missCount(), + queryOnly.recall(), + queryOnly.elapsedSeconds(), + queryOnly.queriesPerSecond(), + formatLatency(queryOnly.latencyStats(), LatencyStatField.P50), + formatLatency(queryOnly.latencyStats(), LatencyStatField.P95), + formatLatency(queryOnly.latencyStats(), LatencyStatField.P99)); + System.out.printf( + Locale.ROOT, + "MIXED_SUMMARY threads=%d rounds=%d inserted_docs=%d insert_failures=%d insert_docs_per_sec=%.1f insert_p50_us=%s insert_p95_us=%s insert_p99_us=%s queries=%d query_failures=%d miss_count=%d recall=%.4f queries_per_sec=%.1f query_p50_us=%s query_p95_us=%s query_p99_us=%s%n", + options.concurrentMixedThreads(), + options.concurrentMixedRounds(), + mixed.insertedDocs(), + mixed.insertFailures(), + mixed.insertDocsPerSecond(), + formatLatency(mixed.insertLatencyStats(), LatencyStatField.P50), + formatLatency(mixed.insertLatencyStats(), LatencyStatField.P95), + formatLatency(mixed.insertLatencyStats(), LatencyStatField.P99), + mixed.queryCount(), + 
mixed.queryFailures(), + mixed.missCount(), + mixed.recall(), + mixed.queriesPerSecond(), + formatLatency(mixed.queryLatencyStats(), LatencyStatField.P50), + formatLatency(mixed.queryLatencyStats(), LatencyStatField.P95), + formatLatency(mixed.queryLatencyStats(), LatencyStatField.P99)); + System.out.println("ARTIFACT_DIR " + runDir); + System.out.println( + "SUGGESTED_RUNS_FROM_MODULE_DIR " + + "\"mvn -q -DskipTests compile exec:java@run-concurrent-stress -Dzvec.stress.args='--docs 100000 --concurrent-query-threads 4 --concurrent-query-count 250'\""); + } + + private static String formatConfig(StressOptions options, Path runDir) { + return "docs=" + + options.docCount() + + " batch_size=" + + options.batchSize() + + " dimension=" + + options.dimension() + + " top_k=" + + options.topK() + + " warmup_queries=" + + options.warmupQueries() + + " concurrent_query_threads=" + + options.concurrentQueryThreads() + + " concurrent_query_count=" + + options.concurrentQueryCount() + + " concurrent_mixed_threads=" + + options.concurrentMixedThreads() + + " concurrent_mixed_rounds=" + + options.concurrentMixedRounds() + + " concurrent_mixed_insert_batch_size=" + + options.concurrentMixedInsertBatchSize() + + " concurrent_mixed_queries_per_round=" + + options.concurrentMixedQueriesPerRound() + + " hnsw_m=" + + formatOptionalHnswM(options) + + " hnsw_ef_construction=" + + formatOptionalHnswEfConstruction(options) + + " hnsw_ef=" + + formatOptionalHnswEf(options) + + " seed=" + + options.seed() + + " run_dir=" + + runDir; + } + + private static VectorQuery buildQuery(PerfData.VectorSample sample, StressOptions options) { + VectorQuery query = + VectorQuery.of("embedding", sample.vector()).topK(sample.topK()).outputFields("title"); + HnswQueryParams hnswQueryParams = options.hnswQueryParams(); + if (hnswQueryParams != null) { + query.hnsw(hnswQueryParams); + } + return query; + } + + private static Path prepareRunDirectory(Path baseDir, int docCount) throws IOException { + 
Files.createDirectories(baseDir); + String runName = "concurrent-docs-" + docCount + "-" + Instant.now().toEpochMilli(); + return baseDir.resolve(runName); + } + + private static String formatOptionalHnswM(StressOptions options) { + return options.hnswIndexParams() == null + ? "default" + : Integer.toString(options.hnswIndexParams().m()); + } + + private static String formatOptionalHnswEfConstruction(StressOptions options) { + return options.hnswIndexParams() == null + ? "default" + : Integer.toString(options.hnswIndexParams().efConstruction()); + } + + private static String formatOptionalHnswEf(StressOptions options) { + return options.hnswQueryParams() == null + ? "default" + : Integer.toString(options.hnswQueryParams().ef()); + } + + private static String formatLatency(LatencyStats stats, LatencyStatField field) { + if (stats == null) { + return "n/a"; + } + double value; + switch (field) { + case P50: + value = stats.p50Micros(); + break; + case P95: + value = stats.p95Micros(); + break; + case P99: + value = stats.p99Micros(); + break; + default: + throw new IllegalStateException("Unhandled latency field: " + field); + } + return String.format(Locale.ROOT, "%.1f", value); + } + + private enum LatencyStatField { + P50, + P95, + P99 + } + + @FunctionalInterface + private interface ThrowingIntConsumer { + void accept(int value) throws Exception; + } + + private static final class QueryOnlyMetrics { + private final int queryCount; + private final int failureCount; + private final int missCount; + private final LatencyStats latencyStats; + private final long elapsedNanos; + + private QueryOnlyMetrics( + int queryCount, + int failureCount, + int missCount, + LatencyStats latencyStats, + long elapsedNanos) { + this.queryCount = queryCount; + this.failureCount = failureCount; + this.missCount = missCount; + this.latencyStats = latencyStats; + this.elapsedNanos = elapsedNanos; + } + + private int queryCount() { + return queryCount; + } + + private int failureCount() { 
+ return failureCount; + } + + private int missCount() { + return missCount; + } + + private LatencyStats latencyStats() { + return latencyStats; + } + + private double recall() { + if (queryCount == 0) { + return 1.0; + } + return (queryCount - failureCount - missCount) / (double) queryCount; + } + + private double elapsedSeconds() { + return elapsedNanos / 1_000_000_000.0; + } + + private double queriesPerSecond() { + return queryCount / elapsedSeconds(); + } + } + + private static final class MixedMetrics { + private final int insertedDocs; + private final int insertFailures; + private final LatencyStats insertLatencyStats; + private final int queryCount; + private final int queryFailures; + private final int missCount; + private final LatencyStats queryLatencyStats; + private final long elapsedNanos; + + private MixedMetrics( + int insertedDocs, + int insertFailures, + LatencyStats insertLatencyStats, + int queryCount, + int queryFailures, + int missCount, + LatencyStats queryLatencyStats, + long elapsedNanos) { + this.insertedDocs = insertedDocs; + this.insertFailures = insertFailures; + this.insertLatencyStats = insertLatencyStats; + this.queryCount = queryCount; + this.queryFailures = queryFailures; + this.missCount = missCount; + this.queryLatencyStats = queryLatencyStats; + this.elapsedNanos = elapsedNanos; + } + + private int insertedDocs() { + return insertedDocs; + } + + private int insertFailures() { + return insertFailures; + } + + private LatencyStats insertLatencyStats() { + return insertLatencyStats; + } + + private int queryCount() { + return queryCount; + } + + private int queryFailures() { + return queryFailures; + } + + private int missCount() { + return missCount; + } + + private LatencyStats queryLatencyStats() { + return queryLatencyStats; + } + + private double recall() { + if (queryCount == 0) { + return 1.0; + } + return (queryCount - queryFailures - missCount) / (double) queryCount; + } + + private double elapsedSeconds() { + return 
elapsedNanos / 1_000_000_000.0; + } + + private double insertDocsPerSecond() { + return insertedDocs / elapsedSeconds(); + } + + private double queriesPerSecond() { + return queryCount / elapsedSeconds(); + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/CollectionStressMain.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/CollectionStressMain.java new file mode 100644 index 000000000..78eef240b --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/CollectionStressMain.java @@ -0,0 +1,438 @@ +package org.zvec.perf; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.List; +import java.util.Locale; +import java.util.Random; +import org.zvec.Collection; +import org.zvec.CollectionSchema; +import org.zvec.Doc; +import org.zvec.HnswQueryParams; +import org.zvec.VectorQuery; +import org.zvec.Zvec; + +public final class CollectionStressMain { + private CollectionStressMain() {} + + public static void main(String[] args) throws Exception { + StressOptions options = StressOptions.parse(args); + Path runDir = prepareRunDirectory(options.workDir(), options.docCount()); + + System.out.println("STRESS_CONFIG " + formatConfig(options, runDir)); + + CollectionSchema schema = PerfData.schema("perf_docs", options.dimension(), options.hnswIndexParams()); + MemorySnapshot memoryBefore = captureMemory(); + + int nextDocIndex = options.docCount(); + try (Collection collection = Zvec.createAndOpen(runDir.toString(), schema)) { + InsertMetrics insertMetrics = insertInitialDataset(collection, options); + collection.flush(); + MemorySnapshot memoryAfterInsert = captureMemory(); + + warmupQueries(collection, options); + QueryMetrics queryMetrics = runQueryPhase(collection, options, options.docCount()); + MemorySnapshot memoryAfterQueries = captureMemory(); + + ReliabilityMetrics reliability = 
runSteadyState(collection, options, nextDocIndex); + MemorySnapshot memoryAfterSteadyState = captureMemory(); + + printSummary( + options, + insertMetrics, + queryMetrics, + memoryBefore, + memoryAfterInsert, + memoryAfterQueries, + memoryAfterSteadyState, + reliability, + runDir); + } + } + + private static InsertMetrics insertInitialDataset(Collection collection, StressOptions options) { + long startedAt = System.nanoTime(); + for (int batchStart = 0; batchStart < options.docCount(); batchStart += options.batchSize()) { + int batchSize = Math.min(options.batchSize(), options.docCount() - batchStart); + List docs = PerfData.docs(batchStart, batchSize, options.dimension(), options.seed()); + int inserted = collection.insert(docs); + if (inserted != batchSize) { + throw new IllegalStateException( + "Inserted count mismatch: expected " + batchSize + ", got " + inserted); + } + } + long elapsedNanos = System.nanoTime() - startedAt; + return new InsertMetrics(options.docCount(), elapsedNanos); + } + + private static void warmupQueries(Collection collection, StressOptions options) { + Random random = new Random(options.seed() ^ 0xA5A5A5A5L); + for (int i = 0; i < options.warmupQueries(); i++) { + int docIndex = random.nextInt(options.docCount()); + PerfData.VectorSample sample = + PerfData.querySample(docIndex, options.dimension(), options.seed(), options.topK()); + collection.query(buildQuery(sample, options)); + } + } + + private static QueryMetrics runQueryPhase( + Collection collection, StressOptions options, int maxDocIndexExclusive) { + Random random = new Random(options.seed() ^ 0x5A5A5A5AL); + long[] latencies = new long[options.queryCount()]; + int missCount = 0; + for (int i = 0; i < options.queryCount(); i++) { + int docIndex = random.nextInt(maxDocIndexExclusive); + PerfData.VectorSample sample = + PerfData.querySample(docIndex, options.dimension(), options.seed(), options.topK()); + long startedAt = System.nanoTime(); + List results = + 
collection.query(buildQuery(sample, options)); + latencies[i] = System.nanoTime() - startedAt; + if (!hasExpectedHit(sample.expectedId(), results)) { + missCount++; + } + } + return new QueryMetrics(LatencyStats.fromNanos(latencies), missCount); + } + + private static ReliabilityMetrics runSteadyState( + Collection collection, StressOptions options, int nextDocIndex) { + long[] steadyLatencies = + new long[Math.max(1, options.steadyStateRounds() * options.steadyQueriesPerRound())]; + int latencyIndex = 0; + int insertFailures = 0; + int queryFailures = 0; + + for (int round = 0; round < options.steadyStateRounds(); round++) { + try { + List docs = + PerfData.docs( + nextDocIndex, options.steadyInsertBatchSize(), options.dimension(), options.seed()); + int inserted = collection.insert(docs); + if (inserted != options.steadyInsertBatchSize()) { + insertFailures++; + } + nextDocIndex += options.steadyInsertBatchSize(); + } catch (RuntimeException e) { + insertFailures++; + } + + for (int q = 0; q < options.steadyQueriesPerRound(); q++) { + int docIndex = nextDocIndex - 1 - q; + if (docIndex < 0) { + docIndex = 0; + } + try { + PerfData.VectorSample sample = + PerfData.querySample(docIndex, options.dimension(), options.seed(), options.topK()); + long startedAt = System.nanoTime(); + List results = + collection.query(buildQuery(sample, options)); + steadyLatencies[latencyIndex++] = System.nanoTime() - startedAt; + if (!hasExpectedHit(sample.expectedId(), results)) { + queryFailures++; + } + } catch (RuntimeException e) { + queryFailures++; + } + } + } + + long[] samples = + latencyIndex == steadyLatencies.length ? steadyLatencies : java.util.Arrays.copyOf(steadyLatencies, latencyIndex); + LatencyStats steadyStats = + samples.length == 0 ? 
null : LatencyStats.fromNanos(samples); + return new ReliabilityMetrics(options.steadyStateRounds(), insertFailures, queryFailures, steadyStats); + } + + private static void printSummary( + StressOptions options, + InsertMetrics insertMetrics, + QueryMetrics queryMetrics, + MemorySnapshot memoryBefore, + MemorySnapshot memoryAfterInsert, + MemorySnapshot memoryAfterQueries, + MemorySnapshot memoryAfterSteadyState, + ReliabilityMetrics reliability, + Path runDir) { + System.out.printf( + Locale.ROOT, + "INSERT_SUMMARY docs=%d seconds=%.3f docs_per_sec=%.0f%n", + insertMetrics.insertedDocs(), + insertMetrics.elapsedSeconds(), + insertMetrics.docsPerSecond()); + System.out.printf( + Locale.ROOT, + "QUERY_SUMMARY count=%d miss_count=%d recall=%.4f mean_us=%.1f p50_us=%.1f p95_us=%.1f p99_us=%.1f max_us=%.1f%n", + queryMetrics.latencyStats().count(), + queryMetrics.missCount(), + queryMetrics.recall(), + queryMetrics.latencyStats().meanMicros(), + queryMetrics.latencyStats().p50Micros(), + queryMetrics.latencyStats().p95Micros(), + queryMetrics.latencyStats().p99Micros(), + queryMetrics.latencyStats().maxMicros()); + System.out.printf( + Locale.ROOT, + "MEMORY_SUMMARY heap_before_mb=%.2f heap_after_insert_mb=%.2f heap_after_queries_mb=%.2f heap_after_steady_mb=%.2f heap_growth_insert_mb=%.2f heap_growth_total_mb=%.2f rss_before_mb=%s rss_after_insert_mb=%s rss_after_queries_mb=%s rss_after_steady_mb=%s rss_growth_insert_mb=%s rss_growth_total_mb=%s%n", + bytesToMiB(memoryBefore.heapBytes()), + bytesToMiB(memoryAfterInsert.heapBytes()), + bytesToMiB(memoryAfterQueries.heapBytes()), + bytesToMiB(memoryAfterSteadyState.heapBytes()), + bytesToMiB(memoryAfterInsert.heapBytes() - memoryBefore.heapBytes()), + bytesToMiB(memoryAfterSteadyState.heapBytes() - memoryBefore.heapBytes()), + formatOptionalMiB(memoryBefore.rssBytes()), + formatOptionalMiB(memoryAfterInsert.rssBytes()), + formatOptionalMiB(memoryAfterQueries.rssBytes()), + 
formatOptionalMiB(memoryAfterSteadyState.rssBytes()), + formatOptionalMiB(delta(memoryAfterInsert.rssBytes(), memoryBefore.rssBytes())), + formatOptionalMiB(delta(memoryAfterSteadyState.rssBytes(), memoryBefore.rssBytes()))); + if (reliability.steadyStats() != null) { + System.out.printf( + Locale.ROOT, + "STEADY_STATE rounds=%d insert_failures=%d query_failures=%d p50_us=%.1f p95_us=%.1f p99_us=%.1f%n", + reliability.rounds(), + reliability.insertFailures(), + reliability.queryFailures(), + reliability.steadyStats().p50Micros(), + reliability.steadyStats().p95Micros(), + reliability.steadyStats().p99Micros()); + } else { + System.out.printf( + Locale.ROOT, + "STEADY_STATE rounds=%d insert_failures=%d query_failures=%d%n", + reliability.rounds(), + reliability.insertFailures(), + reliability.queryFailures()); + } + System.out.println("ARTIFACT_DIR " + runDir); + System.out.println( + "SUGGESTED_RUNS_FROM_MODULE_DIR " + + "\"mvn -q -DskipTests compile exec:java@run-stress -Dzvec.stress.args='--docs 100000'\" " + + "\"mvn -q -DskipTests compile exec:java@run-stress -Dzvec.stress.args='--docs 1000000 --queries 5000 --steady-state-rounds 50'\""); + } + + private static Path prepareRunDirectory(Path baseDir, int docCount) throws IOException { + Files.createDirectories(baseDir); + String runName = "docs-" + docCount + "-" + Instant.now().toEpochMilli(); + return baseDir.resolve(runName); + } + + private static long usedHeapBytes() { + Runtime runtime = Runtime.getRuntime(); + runtime.gc(); + runtime.gc(); + return runtime.totalMemory() - runtime.freeMemory(); + } + + private static MemorySnapshot captureMemory() { + return new MemorySnapshot(usedHeapBytes(), readResidentSetBytes()); + } + + private static Long readResidentSetBytes() { + Process process = null; + try { + process = + new ProcessBuilder( + "ps", "-o", "rss=", "-p", Long.toString(ProcessHandle.current().pid())) + .redirectErrorStream(true) + .start(); + byte[] output = 
process.getInputStream().readAllBytes(); + int exitCode = process.waitFor(); + if (exitCode != 0) { + return null; + } + String value = new String(output, StandardCharsets.UTF_8).trim(); + if (value.isEmpty()) { + return null; + } + return Long.parseLong(value) * 1024L; + } catch (IOException | InterruptedException | NumberFormatException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + return null; + } finally { + if (process != null) { + process.destroy(); + } + } + } + + private static double bytesToMiB(long bytes) { + return bytes / (1024.0 * 1024.0); + } + + static boolean hasExpectedHit(String expectedId, List results) { + for (Doc result : results) { + if (expectedId.equals(result.id())) { + return true; + } + } + return false; + } + + private static String formatOptionalMiB(Long bytes) { + if (bytes == null) { + return "unavailable"; + } + return String.format(Locale.ROOT, "%.2f", bytesToMiB(bytes)); + } + + private static Long delta(Long end, Long start) { + if (end == null || start == null) { + return null; + } + return end - start; + } + + private static String formatConfig(StressOptions options, Path runDir) { + return "docs=" + + options.docCount() + + " queries=" + + options.queryCount() + + " batch_size=" + + options.batchSize() + + " dimension=" + + options.dimension() + + " top_k=" + + options.topK() + + " warmup_queries=" + + options.warmupQueries() + + " steady_state_rounds=" + + options.steadyStateRounds() + + " steady_insert_batch_size=" + + options.steadyInsertBatchSize() + + " steady_queries_per_round=" + + options.steadyQueriesPerRound() + + " hnsw_m=" + + formatOptionalHnswM(options) + + " hnsw_ef_construction=" + + formatOptionalHnswEfConstruction(options) + + " hnsw_ef=" + + formatOptionalHnswEf(options) + + " seed=" + + options.seed() + + " run_dir=" + + runDir; + } + + private static VectorQuery buildQuery(PerfData.VectorSample sample, StressOptions options) { + VectorQuery query = 
VectorQuery.of("embedding", sample.vector()).topK(sample.topK()).outputFields("title"); + HnswQueryParams hnswQueryParams = options.hnswQueryParams(); + if (hnswQueryParams != null) { + query.hnsw(hnswQueryParams); + } + return query; + } + + private static String formatOptionalHnswM(StressOptions options) { + return options.hnswIndexParams() == null ? "default" : Integer.toString(options.hnswIndexParams().m()); + } + + private static String formatOptionalHnswEfConstruction(StressOptions options) { + return options.hnswIndexParams() == null + ? "default" + : Integer.toString(options.hnswIndexParams().efConstruction()); + } + + private static String formatOptionalHnswEf(StressOptions options) { + return options.hnswQueryParams() == null ? "default" : Integer.toString(options.hnswQueryParams().ef()); + } + + private static final class InsertMetrics { + private final int insertedDocs; + private final long elapsedNanos; + + private InsertMetrics(int insertedDocs, long elapsedNanos) { + this.insertedDocs = insertedDocs; + this.elapsedNanos = elapsedNanos; + } + + private int insertedDocs() { + return insertedDocs; + } + + private double elapsedSeconds() { + return elapsedNanos / 1_000_000_000.0; + } + + private double docsPerSecond() { + return insertedDocs / elapsedSeconds(); + } + } + + private static final class MemorySnapshot { + private final long heapBytes; + private final Long rssBytes; + + private MemorySnapshot(long heapBytes, Long rssBytes) { + this.heapBytes = heapBytes; + this.rssBytes = rssBytes; + } + + private long heapBytes() { + return heapBytes; + } + + private Long rssBytes() { + return rssBytes; + } + } + + private static final class QueryMetrics { + private final LatencyStats latencyStats; + private final int missCount; + + private QueryMetrics(LatencyStats latencyStats, int missCount) { + this.latencyStats = latencyStats; + this.missCount = missCount; + } + + private LatencyStats latencyStats() { + return latencyStats; + } + + private int 
/**
 * Immutable summary of a batch of latency samples.
 *
 * <p>Samples are supplied in nanoseconds and every statistic is exposed in microseconds.
 * Percentiles use the nearest-rank rule on the sorted samples: the value at index
 * {@code ceil(p * n) - 1}, clamped to the first element.
 */
public final class LatencyStats {
  private static final double NANOS_PER_MICRO = 1_000.0;

  private final int count;
  private final double minMicros;
  private final double p50Micros;
  private final double p95Micros;
  private final double p99Micros;
  private final double maxMicros;
  private final double meanMicros;

  private LatencyStats(
      int count,
      double minMicros,
      double p50Micros,
      double p95Micros,
      double p99Micros,
      double maxMicros,
      double meanMicros) {
    this.count = count;
    this.minMicros = minMicros;
    this.p50Micros = p50Micros;
    this.p95Micros = p95Micros;
    this.p99Micros = p99Micros;
    this.maxMicros = maxMicros;
    this.meanMicros = meanMicros;
  }

  /**
   * Builds the summary from raw nanosecond samples. The input array is copied before sorting and
   * is never modified.
   *
   * @param samplesNanos latency samples in nanoseconds; must contain at least one element
   * @return the computed statistics
   * @throws IllegalArgumentException if {@code samplesNanos} is {@code null} or empty
   */
  public static LatencyStats fromNanos(long[] samplesNanos) {
    boolean noSamples = samplesNanos == null || samplesNanos.length == 0;
    if (noSamples) {
      throw new IllegalArgumentException("samplesNanos must not be empty");
    }

    long[] ordered = Arrays.copyOf(samplesNanos, samplesNanos.length);
    Arrays.sort(ordered);
    int n = ordered.length;

    long totalNanos = 0L;
    for (int i = 0; i < n; i++) {
      totalNanos += ordered[i];
    }
    double meanNanos = (double) totalNanos / n;

    return new LatencyStats(
        n,
        toMicros(ordered[0]),
        toMicros(nearestRank(ordered, 0.50)),
        toMicros(nearestRank(ordered, 0.95)),
        toMicros(nearestRank(ordered, 0.99)),
        toMicros(ordered[n - 1]),
        toMicros(meanNanos));
  }

  /** Returns the number of samples summarized. */
  public int count() {
    return count;
  }

  /** Returns the smallest sample, in microseconds. */
  public double minMicros() {
    return minMicros;
  }

  /** Returns the 50th-percentile sample, in microseconds. */
  public double p50Micros() {
    return p50Micros;
  }

  /** Returns the 95th-percentile sample, in microseconds. */
  public double p95Micros() {
    return p95Micros;
  }

  /** Returns the 99th-percentile sample, in microseconds. */
  public double p99Micros() {
    return p99Micros;
  }

  /** Returns the largest sample, in microseconds. */
  public double maxMicros() {
    return maxMicros;
  }

  /** Returns the arithmetic mean of the samples, in microseconds. */
  public double meanMicros() {
    return meanMicros;
  }

  /** Nearest-rank percentile over an ascending array; {@code fraction} is in the range 0..1. */
  private static long nearestRank(long[] ascending, double fraction) {
    int rank = (int) Math.ceil(fraction * ascending.length) - 1;
    return ascending[Math.max(rank, 0)];
  }

  /** Converts nanoseconds to microseconds; {@code long} arguments widen to {@code double}. */
  private static double toMicros(double nanos) {
    return nanos / NANOS_PER_MICRO;
  }
}
schema(name, dimension, null); + } + + public static CollectionSchema schema(String name, int dimension, HnswIndexParams hnswIndexParams) { + VectorSchema vector = new VectorSchema("embedding", DataType.VECTOR_FP32, dimension); + if (hnswIndexParams != null) { + vector = vector.withHnswIndex(hnswIndexParams); + } + return new CollectionSchema( + name, + List.of( + new FieldSchema("title", DataType.STRING, false), + new FieldSchema("bucket", DataType.INT64, false)), + List.of(vector)); + } + + public static List docs(int startDocIndex, int count, int dimension, long seed) { + List docs = new ArrayList<>(count); + for (int i = 0; i < count; i++) { + int docIndex = startDocIndex + i; + docs.add(doc(docIndex, dimension, seed)); + } + return docs; + } + + public static Doc doc(int docIndex, int dimension, long seed) { + return Doc.of(docId(docIndex)) + .field("title", "doc-" + docIndex) + .field("bucket", docIndex % 128L) + .vector("embedding", vector(docIndex, dimension, seed)); + } + + public static VectorSample querySample(int docIndex, int dimension, long seed, int topK) { + return new VectorSample(docId(docIndex), vector(docIndex, dimension, seed), topK); + } + + public static String docId(int docIndex) { + return "doc_" + docIndex; + } + + public static float[] vector(int docIndex, int dimension, long seed) { + float[] values = new float[dimension]; + fillVector(values, docIndex, seed); + return values; + } + + public static void fillVector(float[] values, int docIndex, long seed) { + for (int i = 0; i < values.length; i++) { + values[i] = hashedUnitFloat(docIndex, i, seed); + } + } + + private static float hashedUnitFloat(int docIndex, int dimIndex, long seed) { + long mixed = seed; + mixed ^= 0x9E3779B97F4A7C15L * (docIndex + 1L); + mixed ^= 0xBF58476D1CE4E5B9L * (dimIndex + 1L); + mixed = mix64(mixed); + return ((mixed >>> 40) & 0xFFFFFFL) / (float) 0x1000000L; + } + + private static long mix64(long z) { + z = (z ^ (z >>> 30)) * 0xBF58476D1CE4E5B9L; + z = (z ^ 
(z >>> 27)) * 0x94D049BB133111EBL; + return z ^ (z >>> 31); + } + + public static final class VectorSample { + private final String expectedId; + private final float[] vector; + private final int topK; + + public VectorSample(String expectedId, float[] vector, int topK) { + this.expectedId = Objects.requireNonNull(expectedId, "expectedId"); + this.vector = Objects.requireNonNull(vector, "vector"); + this.topK = topK; + } + + public String expectedId() { + return expectedId; + } + + public float[] vector() { + return vector; + } + + public int topK() { + return topK; + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/StressOptions.java b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/StressOptions.java new file mode 100644 index 000000000..ad9d35622 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/main/java/org/zvec/perf/StressOptions.java @@ -0,0 +1,334 @@ +package org.zvec.perf; + +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import org.zvec.HnswIndexParams; +import org.zvec.HnswQueryParams; + +public final class StressOptions { + private final int docCount; + private final int queryCount; + private final int batchSize; + private final int dimension; + private final int topK; + private final int warmupQueries; + private final int steadyStateRounds; + private final int steadyInsertBatchSize; + private final int steadyQueriesPerRound; + private final int concurrentQueryThreads; + private final int concurrentQueryCount; + private final int concurrentMixedThreads; + private final int concurrentMixedRounds; + private final int concurrentMixedInsertBatchSize; + private final int concurrentMixedQueriesPerRound; + private final HnswIndexParams hnswIndexParams; + private final HnswQueryParams hnswQueryParams; + private final long seed; + private final Path workDir; + + public StressOptions( + int docCount, + int queryCount, + int batchSize, + int dimension, + int topK, + 
int warmupQueries, + int steadyStateRounds, + int steadyInsertBatchSize, + int steadyQueriesPerRound, + int concurrentQueryThreads, + int concurrentQueryCount, + int concurrentMixedThreads, + int concurrentMixedRounds, + int concurrentMixedInsertBatchSize, + int concurrentMixedQueriesPerRound, + HnswIndexParams hnswIndexParams, + HnswQueryParams hnswQueryParams, + long seed, + Path workDir) { + requirePositive(docCount, "docCount"); + requirePositive(queryCount, "queryCount"); + requirePositive(batchSize, "batchSize"); + requirePositive(dimension, "dimension"); + requirePositive(topK, "topK"); + requirePositive(warmupQueries, "warmupQueries"); + requireNonNegative(steadyStateRounds, "steadyStateRounds"); + requirePositive(steadyInsertBatchSize, "steadyInsertBatchSize"); + requirePositive(steadyQueriesPerRound, "steadyQueriesPerRound"); + requirePositive(concurrentQueryThreads, "concurrentQueryThreads"); + requirePositive(concurrentQueryCount, "concurrentQueryCount"); + requirePositive(concurrentMixedThreads, "concurrentMixedThreads"); + requirePositive(concurrentMixedRounds, "concurrentMixedRounds"); + requirePositive(concurrentMixedInsertBatchSize, "concurrentMixedInsertBatchSize"); + requirePositive(concurrentMixedQueriesPerRound, "concurrentMixedQueriesPerRound"); + if (workDir == null) { + throw new IllegalArgumentException("workDir must not be null"); + } + this.docCount = docCount; + this.queryCount = queryCount; + this.batchSize = batchSize; + this.dimension = dimension; + this.topK = topK; + this.warmupQueries = warmupQueries; + this.steadyStateRounds = steadyStateRounds; + this.steadyInsertBatchSize = steadyInsertBatchSize; + this.steadyQueriesPerRound = steadyQueriesPerRound; + this.concurrentQueryThreads = concurrentQueryThreads; + this.concurrentQueryCount = concurrentQueryCount; + this.concurrentMixedThreads = concurrentMixedThreads; + this.concurrentMixedRounds = concurrentMixedRounds; + this.concurrentMixedInsertBatchSize = 
concurrentMixedInsertBatchSize; + this.concurrentMixedQueriesPerRound = concurrentMixedQueriesPerRound; + this.hnswIndexParams = hnswIndexParams; + this.hnswQueryParams = hnswQueryParams; + this.seed = seed; + this.workDir = workDir; + } + + public int docCount() { + return docCount; + } + + public int queryCount() { + return queryCount; + } + + public int batchSize() { + return batchSize; + } + + public int dimension() { + return dimension; + } + + public int topK() { + return topK; + } + + public int warmupQueries() { + return warmupQueries; + } + + public int steadyStateRounds() { + return steadyStateRounds; + } + + public int steadyInsertBatchSize() { + return steadyInsertBatchSize; + } + + public int steadyQueriesPerRound() { + return steadyQueriesPerRound; + } + + public int concurrentQueryThreads() { + return concurrentQueryThreads; + } + + public int concurrentQueryCount() { + return concurrentQueryCount; + } + + public int concurrentMixedThreads() { + return concurrentMixedThreads; + } + + public int concurrentMixedRounds() { + return concurrentMixedRounds; + } + + public int concurrentMixedInsertBatchSize() { + return concurrentMixedInsertBatchSize; + } + + public int concurrentMixedQueriesPerRound() { + return concurrentMixedQueriesPerRound; + } + + public HnswIndexParams hnswIndexParams() { + return hnswIndexParams; + } + + public HnswQueryParams hnswQueryParams() { + return hnswQueryParams; + } + + public long seed() { + return seed; + } + + public Path workDir() { + return workDir; + } + + public static StressOptions parse(String[] args) { + Map values = new HashMap<>(); + for (int i = 0; i < args.length; i++) { + String arg = args[i]; + if (!arg.startsWith("--")) { + throw new IllegalArgumentException("Expected option starting with --, got: " + arg); + } + if (i + 1 >= args.length) { + throw new IllegalArgumentException("Missing value for option: " + arg); + } + values.put(arg, args[++i]); + } + + validateKnownOptions(values); + + return new 
StressOptions( + parseInt(values, "--docs", 100_000), + parseInt(values, "--queries", 1_000), + parseInt(values, "--batch-size", 1_000), + parseInt(values, "--dimension", 128), + parseInt(values, "--top-k", 10), + parseInt(values, "--warmup-queries", 100), + parseInt(values, "--steady-state-rounds", 20), + parseInt(values, "--steady-insert-batch-size", 100), + parseInt(values, "--steady-queries-per-round", 20), + parseInt(values, "--concurrent-query-threads", 2), + parseInt(values, "--concurrent-query-count", 20), + parseInt(values, "--concurrent-mixed-threads", 2), + parseInt(values, "--concurrent-mixed-rounds", 2), + parseInt(values, "--concurrent-mixed-insert-batch-size", 5), + parseInt(values, "--concurrent-mixed-queries-per-round", 5), + parseHnswIndexParams(values), + parseHnswQueryParams(values), + parseLong(values, "--seed", 7L), + Path.of(values.getOrDefault("--work-dir", "target/perf/zvec-stress"))); + } + + private static void validateKnownOptions(Map values) { + for (String key : values.keySet()) { + if (!isKnownOption(key)) { + throw new IllegalArgumentException("Unknown option: " + key); + } + } + } + + private static boolean isKnownOption(String key) { + switch (key) { + case "--docs": + case "--queries": + case "--batch-size": + case "--dimension": + case "--top-k": + case "--warmup-queries": + case "--steady-state-rounds": + case "--steady-insert-batch-size": + case "--steady-queries-per-round": + case "--concurrent-query-threads": + case "--concurrent-query-count": + case "--concurrent-mixed-threads": + case "--concurrent-mixed-rounds": + case "--concurrent-mixed-insert-batch-size": + case "--concurrent-mixed-queries-per-round": + case "--hnsw-m": + case "--hnsw-ef-construction": + case "--hnsw-ef": + case "--seed": + case "--work-dir": + return true; + default: + return false; + } + } + + private static HnswIndexParams parseHnswIndexParams(Map values) { + boolean hasM = values.containsKey("--hnsw-m"); + boolean hasEfConstruction = 
values.containsKey("--hnsw-ef-construction"); + if (hasM != hasEfConstruction) { + throw new IllegalArgumentException( + "Both --hnsw-m and --hnsw-ef-construction are required together"); + } + if (!hasM) { + return null; + } + return new HnswIndexParams( + Integer.parseInt(values.get("--hnsw-m")), + Integer.parseInt(values.get("--hnsw-ef-construction"))); + } + + private static HnswQueryParams parseHnswQueryParams(Map values) { + if (!values.containsKey("--hnsw-ef")) { + return null; + } + return new HnswQueryParams(Integer.parseInt(values.get("--hnsw-ef")), 0.0f, false, false); + } + + private static int parseInt(Map values, String key, int defaultValue) { + return values.containsKey(key) ? Integer.parseInt(values.get(key)) : defaultValue; + } + + private static long parseLong(Map values, String key, long defaultValue) { + return values.containsKey(key) ? Long.parseLong(values.get(key)) : defaultValue; + } + + private static void requirePositive(int value, String name) { + if (value <= 0) { + throw new IllegalArgumentException(name + " must be > 0"); + } + } + + private static void requireNonNegative(int value, String name) { + if (value < 0) { + throw new IllegalArgumentException(name + " must be >= 0"); + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof StressOptions)) { + return false; + } + StressOptions other = (StressOptions) obj; + return docCount == other.docCount + && queryCount == other.queryCount + && batchSize == other.batchSize + && dimension == other.dimension + && topK == other.topK + && warmupQueries == other.warmupQueries + && steadyStateRounds == other.steadyStateRounds + && steadyInsertBatchSize == other.steadyInsertBatchSize + && steadyQueriesPerRound == other.steadyQueriesPerRound + && concurrentQueryThreads == other.concurrentQueryThreads + && concurrentQueryCount == other.concurrentQueryCount + && concurrentMixedThreads == other.concurrentMixedThreads + && 
concurrentMixedRounds == other.concurrentMixedRounds + && concurrentMixedInsertBatchSize == other.concurrentMixedInsertBatchSize + && concurrentMixedQueriesPerRound == other.concurrentMixedQueriesPerRound + && seed == other.seed + && Objects.equals(hnswIndexParams, other.hnswIndexParams) + && Objects.equals(hnswQueryParams, other.hnswQueryParams) + && workDir.equals(other.workDir); + } + + @Override + public int hashCode() { + return Objects.hash( + docCount, + queryCount, + batchSize, + dimension, + topK, + warmupQueries, + steadyStateRounds, + steadyInsertBatchSize, + steadyQueriesPerRound, + concurrentQueryThreads, + concurrentQueryCount, + concurrentMixedThreads, + concurrentMixedRounds, + concurrentMixedInsertBatchSize, + concurrentMixedQueriesPerRound, + hnswIndexParams, + hnswQueryParams, + seed, + workDir); + } +} diff --git a/java/zvec-java/zvec-java-ffm/pom.xml b/java/zvec-java/zvec-java-ffm/pom.xml new file mode 100644 index 000000000..b08d72b43 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/pom.xml @@ -0,0 +1,104 @@ + + 4.0.0 + + + org.zvec + zvec-java-parent + 0.0.1-SNAPSHOT + + + zvec-java-ffm + zvec-java-ffm + + + 25 + + + + + org.zvec + zvec-java-api + ${project.version} + + + org.zvec + zvec-java-api + ${project.version} + test-jar + tests + test + + + org.junit.jupiter + junit-jupiter + test + + + org.openjdk.jmh + jmh-core + test + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + ${maven.compiler.release} + + + + org.codehaus.mojo + exec-maven-plugin + + + build-native-library + process-resources + + exec + + + ${project.basedir}/../../../scripts/build_java_native.sh + + ${project.build.outputDirectory}/META-INF/native + ffm + ${zvec.native.platform} + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + false + --enable-native-access=ALL-UNNAMED -Dorg.zvec.backend=ffm + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + require-java-25 + + enforce + + + + + [25,26) + + + + + + + + + diff --git 
a/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmCollections.java b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmCollections.java new file mode 100644 index 000000000..547d7336e --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmCollections.java @@ -0,0 +1,225 @@ +package org.zvec.internal.ffm; + +import org.zvec.internal.ZvecException; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_LONG; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.List; +import java.util.Objects; +import org.zvec.CollectionSchema; +import org.zvec.Doc; +import org.zvec.VectorQuery; +import org.zvec.VectorSchema; +import org.zvec.internal.NativeHandle; +import org.zvec.internal.NativeOpenResult; + +public final class FfmCollections { + private FfmCollections() {} + + public static NativeOpenResult createAndOpen(String path, CollectionSchema schema) { + Objects.requireNonNull(path, "path"); + Objects.requireNonNull(schema, "schema"); + + MemorySegment nativeSchema = FfmSchemas.toNative(schema); + MemorySegment handle = MemorySegment.NULL; + try (Arena arena = Arena.ofConfined()) { + MemorySegment outCollection = arena.allocate(ADDRESS); + FfmNative.check( + (int) + FfmNative.handleCollectionCreateAndOpen() + .invokeExact(arena.allocateFrom(path), nativeSchema, MemorySegment.NULL, outCollection), + "zvec_collection_create_and_open"); + handle = outCollection.get(ADDRESS, 0); + CollectionSchema querySchema = readSchema(handle); + return new NativeOpenResult(new FfmHandle(handle), querySchema); + } catch (Throwable t) { + closeQuietly(handle); + throw propagate("Failed to create collection", t); + } finally { + FfmSchemas.destroy(nativeSchema); + } + } + + public static NativeOpenResult open(String path) { + Objects.requireNonNull(path, "path"); + + MemorySegment handle = MemorySegment.NULL; + try (Arena arena = 
Arena.ofConfined()) { + MemorySegment outCollection = arena.allocate(ADDRESS); + FfmNative.check( + (int) + FfmNative.handleCollectionOpen() + .invokeExact(arena.allocateFrom(path), MemorySegment.NULL, outCollection), + "zvec_collection_open"); + handle = outCollection.get(ADDRESS, 0); + CollectionSchema querySchema = readSchema(handle); + return new NativeOpenResult(new FfmHandle(handle), querySchema); + } catch (Throwable t) { + closeQuietly(handle); + throw propagate("Failed to open collection", t); + } + } + + public static void close(NativeHandle handle) { + try { + FfmNative.check( + (int) FfmNative.handleCollectionClose().invokeExact(segment(handle)), "zvec_collection_close"); + } catch (Throwable t) { + throw propagate("Failed to close collection", t); + } + } + + public static void flush(NativeHandle handle) { + try { + FfmNative.check( + (int) FfmNative.handleCollectionFlush().invokeExact(segment(handle)), "zvec_collection_flush"); + } catch (Throwable t) { + throw propagate("Failed to flush collection", t); + } + } + + public static CollectionSchema readSchema(NativeHandle handle) { + return readSchema(segment(handle)); + } + + private static CollectionSchema readSchema(MemorySegment handle) { + MemorySegment schemaHandle = MemorySegment.NULL; + try (Arena arena = Arena.ofConfined()) { + MemorySegment outSchema = arena.allocate(ADDRESS); + FfmNative.check( + (int) FfmNative.handleCollectionGetSchema().invokeExact(handle, outSchema), + "zvec_collection_get_schema"); + schemaHandle = outSchema.get(ADDRESS, 0); + return FfmSchemas.fromNative(schemaHandle); + } catch (Throwable t) { + throw propagate("Failed to read collection schema", t); + } finally { + FfmSchemas.destroy(schemaHandle); + } + } + + public static int insert(NativeHandle collectionHandle, CollectionSchema schema, List docs) { + MemorySegment collectionSegment = segment(collectionHandle); + Objects.requireNonNull(schema, "schema"); + Objects.requireNonNull(docs, "docs"); + if (docs.isEmpty()) 
{ + return 0; + } + + List nativeDocs = FfmDocs.toFfmDocs(docs, schema); + try (Arena arena = Arena.ofConfined()) { + MemorySegment docArray = arena.allocate(ADDRESS, nativeDocs.size()); + for (int i = 0; i < nativeDocs.size(); i++) { + docArray.setAtIndex(ADDRESS, i, nativeDocs.get(i)); + } + + MemorySegment successCount = arena.allocate(JAVA_LONG); + MemorySegment errorCount = arena.allocate(JAVA_LONG); + FfmNative.check( + (int) + FfmNative.handleCollectionInsert() + .invokeExact( + collectionSegment, + docArray, + (long) nativeDocs.size(), + successCount, + errorCount), + "zvec_collection_insert"); + + long errors = errorCount.get(JAVA_LONG, 0); + if (errors != 0L) { + throw new ZvecException( + -1, "zvec_collection_insert reported " + errors + " per-document failures"); + } + return Math.toIntExact(successCount.get(JAVA_LONG, 0)); + } catch (Throwable t) { + throw propagate("Failed to insert documents", t); + } finally { + FfmDocs.destroyAll(nativeDocs); + } + } + + public static List query( + NativeHandle collectionHandle, + CollectionSchema querySchema, + CollectionSchema resultSchema, + VectorQuery query) { + MemorySegment collectionSegment = segment(collectionHandle); + Objects.requireNonNull(querySchema, "querySchema"); + Objects.requireNonNull(resultSchema, "resultSchema"); + Objects.requireNonNull(query, "query"); + + VectorSchema runtimeVectorSchema = querySchema.vector(query.fieldName()); + if (runtimeVectorSchema == null) { + throw new IllegalArgumentException("Unknown vector field: " + query.fieldName()); + } + VectorSchema publicVectorSchema = resultSchema.vector(query.fieldName()); + if (publicVectorSchema == null) { + throw new IllegalArgumentException("Unknown vector field: " + query.fieldName()); + } + MemorySegment nativeQuery = FfmQueries.toNative(runtimeVectorSchema, publicVectorSchema, query); + MemorySegment nativeResults = MemorySegment.NULL; + long resultCount = 0L; + try (Arena arena = Arena.ofConfined()) { + MemorySegment outResults = 
arena.allocate(ADDRESS); + MemorySegment outResultCount = arena.allocate(JAVA_LONG); + FfmNative.check( + (int) + FfmNative.handleCollectionQuery() + .invokeExact(collectionSegment, nativeQuery, outResults, outResultCount), + "zvec_collection_query"); + + nativeResults = outResults.get(ADDRESS, 0); + resultCount = outResultCount.get(JAVA_LONG, 0); + if (resultCount == 0) { + return List.of(); + } + return FfmDocs.fromNativeQueryDocs(nativeResults, resultCount, resultSchema); + } catch (Throwable t) { + throw propagate("Failed to query documents", t); + } finally { + freeFfmDocsQuietly(nativeResults, resultCount); + FfmQueries.destroy(nativeQuery); + } + } + + private static void closeQuietly(MemorySegment handle) { + if (handle == null || handle.equals(MemorySegment.NULL)) { + return; + } + + try { + FfmNative.check( + (int) FfmNative.handleCollectionClose().invokeExact(handle), "zvec_collection_close"); + } catch (Throwable ignored) { + } + } + + private static void freeFfmDocsQuietly(MemorySegment docs, long count) { + if (docs == null || docs.equals(MemorySegment.NULL)) { + return; + } + try { + FfmNative.handleDocsFree().invokeExact(docs, count); + } catch (Throwable ignored) { + } + } + + private static MemorySegment segment(NativeHandle handle) { + Objects.requireNonNull(handle, "handle"); + if (handle instanceof FfmHandle ffmHandle) { + return ffmHandle.segment(); + } + throw new IllegalArgumentException("Handle is not an FFM handle: " + handle.getClass().getName()); + } + + private static RuntimeException propagate(String message, Throwable cause) { + if (cause instanceof RuntimeException runtimeException) { + return runtimeException; + } + return new IllegalStateException(message, cause); + } +} diff --git a/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmDocs.java b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmDocs.java new file mode 100644 index 000000000..037fdf277 --- /dev/null +++ 
b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmDocs.java @@ -0,0 +1,428 @@ +package org.zvec.internal.ffm; + +import org.zvec.internal.ZvecException; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_BOOLEAN; +import static java.lang.foreign.ValueLayout.JAVA_BYTE; +import static java.lang.foreign.ValueLayout.JAVA_DOUBLE; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.zvec.CollectionSchema; +import org.zvec.Doc; +import org.zvec.FieldSchema; +import org.zvec.VectorSchema; + +public final class FfmDocs { + private FfmDocs() {} + + public static List toFfmDocs(List docs, CollectionSchema schema) { + Objects.requireNonNull(docs, "docs"); + Objects.requireNonNull(schema, "schema"); + + List nativeDocs = new ArrayList<>(docs.size()); + try { + for (Doc doc : docs) { + nativeDocs.add(toNativeDoc(Objects.requireNonNull(doc, "doc"), schema)); + } + return nativeDocs; + } catch (RuntimeException e) { + destroyAll(nativeDocs); + throw e; + } + } + + public static MemorySegment toNativeDoc(Doc doc, CollectionSchema schema) { + Objects.requireNonNull(doc, "doc"); + Objects.requireNonNull(schema, "schema"); + + MemorySegment nativeDoc = MemorySegment.NULL; + try (Arena arena = Arena.ofConfined()) { + nativeDoc = (MemorySegment) FfmNative.handleDocCreate().invokeExact(); + if (nativeDoc.equals(MemorySegment.NULL)) { + throw FfmNative.lastError("zvec_doc_create", -1); + } + + FfmNative.handleDocSetPk().invokeExact(nativeDoc, arena.allocateFrom(doc.id())); + + for (Map.Entry entry : doc.fields().entrySet()) { + FieldSchema field = schema.field(entry.getKey()); + if (field == null) { + throw new 
IllegalArgumentException("Unknown scalar field: " + entry.getKey()); + } + writeScalarField(nativeDoc, field, entry.getValue(), arena); + } + + for (String fieldName : doc.nullFields()) { + FieldSchema field = schema.field(fieldName); + if (field == null) { + throw new IllegalArgumentException("Unknown scalar field: " + fieldName); + } + if (!field.nullable()) { + throw new IllegalArgumentException("Field is not nullable: " + fieldName); + } + FfmNative.check( + (int) + FfmNative.handleDocSetFieldNull() + .invokeExact(nativeDoc, arena.allocateFrom(fieldName)), + "zvec_doc_set_field_null"); + } + + for (Map.Entry entry : doc.vectors().entrySet()) { + String vectorName = entry.getKey(); + VectorSchema vector = schema.vector(vectorName); + if (vector == null) { + throw new IllegalArgumentException("Unknown vector field: " + vectorName); + } + + float[] values = Objects.requireNonNull(entry.getValue(), "vector values"); + if (values.length != vector.dimension()) { + throw new IllegalArgumentException( + "Vector dimension mismatch for field " + + vectorName + + ": expected " + + vector.dimension() + + ", got " + + values.length); + } + + MemorySegment valueSegment = arena.allocateFrom(JAVA_FLOAT, values); + FfmNative.check( + (int) + FfmNative.handleDocAddFieldByValue() + .invokeExact( + nativeDoc, + arena.allocateFrom(vectorName), + vector.dataType().code(), + valueSegment, + (long) values.length * Float.BYTES), + "zvec_doc_add_field_by_value"); + } + + return nativeDoc; + } catch (Throwable t) { + destroyQuietly(nativeDoc); + throw propagate("Failed to convert doc to native", t); + } + } + + public static List fromNativeQueryDocs( + MemorySegment nativeDocs, long count, CollectionSchema schema) { + Objects.requireNonNull(nativeDocs, "nativeDocs"); + Objects.requireNonNull(schema, "schema"); + if (count == 0) { + return List.of(); + } + + MemorySegment docsArray = nativeDocs.reinterpret(count * ADDRESS.byteSize()); + List docs = new ArrayList<>(Math.toIntExact(count)); 
+ for (long i = 0; i < count; i++) { + docs.add(fromNativeQueryDoc(docsArray.getAtIndex(ADDRESS, i), schema)); + } + return docs; + } + + public static void destroyAll(List docs) { + if (docs == null) { + return; + } + for (MemorySegment doc : docs) { + destroy(doc); + } + } + + public static void destroy(MemorySegment doc) { + if (doc == null || doc.equals(MemorySegment.NULL)) { + return; + } + try { + FfmNative.handleDocDestroy().invokeExact(doc); + } catch (Throwable t) { + throw propagate("Failed to destroy native doc", t); + } + } + + private static Doc fromNativeQueryDoc(MemorySegment nativeDoc, CollectionSchema schema) { + MemorySegment pkCopy = MemorySegment.NULL; + try { + pkCopy = (MemorySegment) FfmNative.handleDocGetPkCopy().invokeExact(nativeDoc); + String id = pkCopy.equals(MemorySegment.NULL) ? "" : FfmNative.readUtf8CString(pkCopy); + double score = (float) FfmNative.handleDocGetScore().invokeExact(nativeDoc); + + Doc doc = Doc.result(id, score); + for (String fieldName : readFieldNames(nativeDoc)) { + FieldSchema scalarField = schema.field(fieldName); + if (scalarField != null) { + readScalarField(nativeDoc, doc, scalarField); + continue; + } + + VectorSchema vectorField = schema.vector(fieldName); + if (vectorField != null) { + readVectorField(nativeDoc, doc, vectorField); + } + } + return doc; + } catch (Throwable t) { + throw propagate("Failed to convert native query doc", t); + } finally { + FfmNative.free(pkCopy); + } + } + + private static List readFieldNames(MemorySegment nativeDoc) { + MemorySegment namesArray = MemorySegment.NULL; + long count = 0L; + try (Arena arena = Arena.ofConfined()) { + MemorySegment outNames = arena.allocate(ADDRESS); + MemorySegment outCount = arena.allocate(JAVA_LONG); + FfmNative.check( + (int) FfmNative.handleDocGetFieldNames().invokeExact(nativeDoc, outNames, outCount), + "zvec_doc_get_field_names"); + namesArray = outNames.get(ADDRESS, 0); + count = outCount.get(JAVA_LONG, 0); + + if (count == 0) { + return 
List.of(); + } + + MemorySegment names = namesArray.reinterpret(count * ADDRESS.byteSize()); + List fieldNames = new ArrayList<>(Math.toIntExact(count)); + for (long i = 0; i < count; i++) { + fieldNames.add(FfmNative.readUtf8CString(names.getAtIndex(ADDRESS, i))); + } + return fieldNames; + } catch (Throwable t) { + throw propagate("Failed to read native doc field names", t); + } finally { + freeStrArrayQuietly(namesArray, count); + } + } + + private static void writeScalarField( + MemorySegment nativeDoc, FieldSchema field, Object value, Arena arena) throws Throwable { + MemorySegment fieldName = arena.allocateFrom(field.name()); + switch (field.dataType()) { + case STRING -> { + if (!(value instanceof String stringValue)) { + throw new IllegalArgumentException("Field " + field.name() + " expects STRING"); + } + byte[] utf8 = stringValue.getBytes(StandardCharsets.UTF_8); + MemorySegment valueSegment = arena.allocateFrom(JAVA_BYTE, utf8); + FfmNative.check( + (int) + FfmNative.handleDocAddFieldByValue() + .invokeExact( + nativeDoc, + fieldName, + field.dataType().code(), + valueSegment, + (long) utf8.length), + "zvec_doc_add_field_by_value"); + } + case BOOL -> { + if (!(value instanceof Boolean boolValue)) { + throw new IllegalArgumentException("Field " + field.name() + " expects BOOL"); + } + MemorySegment valueSegment = arena.allocate(JAVA_BOOLEAN); + valueSegment.set(JAVA_BOOLEAN, 0, boolValue); + FfmNative.check( + (int) + FfmNative.handleDocAddFieldByValue() + .invokeExact( + nativeDoc, + fieldName, + field.dataType().code(), + valueSegment, + JAVA_BOOLEAN.byteSize()), + "zvec_doc_add_field_by_value"); + } + case INT64 -> { + if (!(value instanceof Long longValue)) { + throw new IllegalArgumentException("Field " + field.name() + " expects INT64"); + } + MemorySegment valueSegment = arena.allocate(JAVA_LONG); + valueSegment.set(JAVA_LONG, 0, longValue); + FfmNative.check( + (int) + FfmNative.handleDocAddFieldByValue() + .invokeExact( + nativeDoc, + fieldName, 
+ field.dataType().code(), + valueSegment, + JAVA_LONG.byteSize()), + "zvec_doc_add_field_by_value"); + } + case DOUBLE -> { + if (!(value instanceof Double doubleValue)) { + throw new IllegalArgumentException("Field " + field.name() + " expects DOUBLE"); + } + MemorySegment valueSegment = arena.allocate(JAVA_DOUBLE); + valueSegment.set(JAVA_DOUBLE, 0, doubleValue); + FfmNative.check( + (int) + FfmNative.handleDocAddFieldByValue() + .invokeExact( + nativeDoc, + fieldName, + field.dataType().code(), + valueSegment, + JAVA_DOUBLE.byteSize()), + "zvec_doc_add_field_by_value"); + } + default -> throw new IllegalArgumentException("Unsupported scalar type: " + field.dataType()); + } + } + + private static void readScalarField(MemorySegment nativeDoc, Doc target, FieldSchema field) { + try (Arena arena = Arena.ofConfined()) { + if (isFieldNull(nativeDoc, field.name(), arena)) { + target.nullField(field.name()); + return; + } + + ValueCopy valueCopy = + readFieldValueCopy(nativeDoc, field.name(), field.dataType().code(), arena); + try { + switch (field.dataType()) { + case STRING -> target.field(field.name(), readString(valueCopy)); + case BOOL -> target.field(field.name(), readBoolean(valueCopy)); + case INT64 -> target.field(field.name(), readInt64(valueCopy)); + case DOUBLE -> target.field(field.name(), readDouble(valueCopy)); + default -> throw new IllegalArgumentException("Unsupported scalar type: " + field.dataType()); + } + } finally { + FfmNative.free(valueCopy.valuePtr); + } + } catch (Throwable t) { + throw propagate("Failed to read scalar query field: " + field.name(), t); + } + } + + private static boolean isFieldNull(MemorySegment nativeDoc, String fieldName, Arena arena) + throws Throwable { + return (boolean) + FfmNative.handleDocIsFieldNull().invokeExact(nativeDoc, arena.allocateFrom(fieldName)); + } + + private static void readVectorField(MemorySegment nativeDoc, Doc target, VectorSchema vectorField) { + if (vectorField.dataType() != 
org.zvec.DataType.VECTOR_FP32) { + throw new IllegalArgumentException("Unsupported vector type: " + vectorField.dataType()); + } + + try (Arena arena = Arena.ofConfined()) { + ValueCopy valueCopy = + readFieldValueCopy(nativeDoc, vectorField.name(), vectorField.dataType().code(), arena); + try { + target.vector(vectorField.name(), readFp32Vector(valueCopy)); + } finally { + FfmNative.free(valueCopy.valuePtr); + } + } catch (Throwable t) { + throw propagate("Failed to read vector query field: " + vectorField.name(), t); + } + } + + private static ValueCopy readFieldValueCopy( + MemorySegment nativeDoc, String fieldName, int dataTypeCode, Arena arena) throws Throwable { + MemorySegment outValue = arena.allocate(ADDRESS); + MemorySegment outSize = arena.allocate(JAVA_LONG); + FfmNative.check( + (int) + FfmNative.handleDocGetFieldValueCopy() + .invokeExact(nativeDoc, arena.allocateFrom(fieldName), dataTypeCode, outValue, outSize), + "zvec_doc_get_field_value_copy"); + return new ValueCopy(outValue.get(ADDRESS, 0), outSize.get(JAVA_LONG, 0)); + } + + private static String readString(ValueCopy valueCopy) { + if (valueCopy.valueSize == 0) { + return ""; + } + if (valueCopy.valuePtr.equals(MemorySegment.NULL)) { + throw new IllegalStateException("String value pointer is null"); + } + byte[] bytes = valueCopy.valuePtr.reinterpret(valueCopy.valueSize).toArray(JAVA_BYTE); + return new String(bytes, StandardCharsets.UTF_8); + } + + private static boolean readBoolean(ValueCopy valueCopy) { + if (valueCopy.valueSize < JAVA_BOOLEAN.byteSize() || valueCopy.valuePtr.equals(MemorySegment.NULL)) { + throw new IllegalStateException("Invalid BOOL value"); + } + return valueCopy.valuePtr.reinterpret(JAVA_BOOLEAN.byteSize()).get(JAVA_BOOLEAN, 0); + } + + private static long readInt64(ValueCopy valueCopy) { + if (valueCopy.valueSize < JAVA_LONG.byteSize() || valueCopy.valuePtr.equals(MemorySegment.NULL)) { + throw new IllegalStateException("Invalid INT64 value"); + } + return valueCopy.valuePtr.reinterpret(JAVA_LONG.byteSize()).get(JAVA_LONG, 0); + } + + 
private static double readDouble(ValueCopy valueCopy) { + if (valueCopy.valueSize < JAVA_DOUBLE.byteSize() || valueCopy.valuePtr.equals(MemorySegment.NULL)) { + throw new IllegalStateException("Invalid DOUBLE value"); + } + return valueCopy.valuePtr.reinterpret(JAVA_DOUBLE.byteSize()).get(JAVA_DOUBLE, 0); + } + + private static float[] readFp32Vector(ValueCopy valueCopy) { + if (valueCopy.valueSize == 0) { + return new float[0]; + } + if (valueCopy.valuePtr.equals(MemorySegment.NULL)) { + throw new IllegalStateException("Vector value pointer is null"); + } + if (valueCopy.valueSize % Float.BYTES != 0) { + throw new IllegalStateException("Invalid VECTOR_FP32 byte size: " + valueCopy.valueSize); + } + return valueCopy.valuePtr.reinterpret(valueCopy.valueSize).toArray(JAVA_FLOAT); + } + + private static void destroyQuietly(MemorySegment nativeDoc) { + if (nativeDoc == null || nativeDoc.equals(MemorySegment.NULL)) { + return; + } + try { + FfmNative.handleDocDestroy().invokeExact(nativeDoc); + } catch (Throwable ignored) { + } + } + + private static void freeStrArrayQuietly(MemorySegment namesArray, long count) { + if (namesArray == null || namesArray.equals(MemorySegment.NULL)) { + return; + } + try { + FfmNative.handleFreeStrArray().invokeExact(namesArray, count); + } catch (Throwable ignored) { + } + } + + private static RuntimeException propagate(String message, Throwable cause) { + if (cause instanceof RuntimeException runtimeException) { + return runtimeException; + } + return new IllegalStateException(message, cause); + } + + private static final class ValueCopy { + private final MemorySegment valuePtr; + private final long valueSize; + + private ValueCopy(MemorySegment valuePtr, long valueSize) { + this.valuePtr = valuePtr; + this.valueSize = valueSize; + } + } +} diff --git a/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmHandle.java b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmHandle.java new file mode 100644 index 000000000..26a3edfc6 --- /dev/null 
+++ b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmHandle.java @@ -0,0 +1,11 @@ +package org.zvec.internal.ffm; + +import java.lang.foreign.MemorySegment; +import java.util.Objects; +import org.zvec.internal.NativeHandle; + +record FfmHandle(MemorySegment segment) implements NativeHandle { + FfmHandle { + Objects.requireNonNull(segment, "segment"); + } +} diff --git a/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmNative.java b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmNative.java new file mode 100644 index 000000000..4a0652975 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmNative.java @@ -0,0 +1,511 @@ +package org.zvec.internal.ffm; + +import org.zvec.internal.ZvecException; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_BOOLEAN; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; + +import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.Linker; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SymbolLookup; +import java.lang.invoke.MethodHandle; +import java.util.concurrent.atomic.AtomicBoolean; + +public final class FfmNative { + public static final int ZVEC_OK = 0; + static final int ZVEC_INDEX_TYPE_HNSW = 1; + static final int ZVEC_METRIC_TYPE_L2 = 1; + private static final long VERSION_STRING_MAX_BYTES = 256; + private static final Linker LINKER = Linker.nativeLinker(); + private static final SymbolLookup LOOKUP; + private static final AtomicBoolean SHUTDOWN_HOOK_REGISTERED = new AtomicBoolean(false); + private static final Object INIT_LOCK = new Object(); + + private static final MethodHandle ZVEC_GET_VERSION; + private static final MethodHandle ZVEC_IS_INITIALIZED; + private static final MethodHandle 
ZVEC_INITIALIZE; + private static final MethodHandle ZVEC_SHUTDOWN; + private static final MethodHandle ZVEC_GET_LAST_ERROR; + private static final MethodHandle ZVEC_FREE; + private static final MethodHandle ZVEC_COLLECTION_SCHEMA_CREATE; + private static final MethodHandle ZVEC_COLLECTION_SCHEMA_DESTROY; + private static final MethodHandle ZVEC_COLLECTION_SCHEMA_ADD_FIELD; + private static final MethodHandle ZVEC_COLLECTION_SCHEMA_GET_NAME; + private static final MethodHandle ZVEC_COLLECTION_SCHEMA_GET_FORWARD_FIELDS; + private static final MethodHandle ZVEC_COLLECTION_SCHEMA_GET_VECTOR_FIELDS; + private static final MethodHandle ZVEC_FIELD_SCHEMA_CREATE; + private static final MethodHandle ZVEC_FIELD_SCHEMA_DESTROY; + private static final MethodHandle ZVEC_FIELD_SCHEMA_GET_NAME; + private static final MethodHandle ZVEC_FIELD_SCHEMA_GET_DATA_TYPE; + private static final MethodHandle ZVEC_FIELD_SCHEMA_IS_NULLABLE; + private static final MethodHandle ZVEC_FIELD_SCHEMA_GET_DIMENSION; + private static final MethodHandle ZVEC_FIELD_SCHEMA_GET_INDEX_PARAMS; + private static final MethodHandle ZVEC_FIELD_SCHEMA_SET_INDEX_PARAMS; + private static final MethodHandle ZVEC_INDEX_PARAMS_CREATE; + private static final MethodHandle ZVEC_INDEX_PARAMS_DESTROY; + private static final MethodHandle ZVEC_INDEX_PARAMS_GET_TYPE; + private static final MethodHandle ZVEC_INDEX_PARAMS_SET_METRIC_TYPE; + private static final MethodHandle ZVEC_INDEX_PARAMS_SET_HNSW_PARAMS; + private static final MethodHandle ZVEC_INDEX_PARAMS_GET_HNSW_M; + private static final MethodHandle ZVEC_INDEX_PARAMS_GET_HNSW_EF_CONSTRUCTION; + private static final MethodHandle ZVEC_COLLECTION_CREATE_AND_OPEN; + private static final MethodHandle ZVEC_COLLECTION_OPEN; + private static final MethodHandle ZVEC_COLLECTION_CLOSE; + private static final MethodHandle ZVEC_COLLECTION_FLUSH; + private static final MethodHandle ZVEC_COLLECTION_GET_SCHEMA; + private static final MethodHandle ZVEC_COLLECTION_INSERT; + private 
static final MethodHandle ZVEC_COLLECTION_QUERY; + private static final MethodHandle ZVEC_DOC_CREATE; + private static final MethodHandle ZVEC_DOC_DESTROY; + private static final MethodHandle ZVEC_DOCS_FREE; + private static final MethodHandle ZVEC_DOC_SET_PK; + private static final MethodHandle ZVEC_DOC_SET_FIELD_NULL; + private static final MethodHandle ZVEC_DOC_ADD_FIELD_BY_VALUE; + private static final MethodHandle ZVEC_DOC_GET_PK_COPY; + private static final MethodHandle ZVEC_DOC_GET_SCORE; + private static final MethodHandle ZVEC_DOC_GET_FIELD_NAMES; + private static final MethodHandle ZVEC_DOC_IS_FIELD_NULL; + private static final MethodHandle ZVEC_DOC_GET_FIELD_VALUE_COPY; + private static final MethodHandle ZVEC_FREE_STR_ARRAY; + private static final MethodHandle ZVEC_VECTOR_QUERY_CREATE; + private static final MethodHandle ZVEC_VECTOR_QUERY_DESTROY; + private static final MethodHandle ZVEC_VECTOR_QUERY_SET_TOPK; + private static final MethodHandle ZVEC_VECTOR_QUERY_SET_FIELD_NAME; + private static final MethodHandle ZVEC_VECTOR_QUERY_SET_QUERY_VECTOR; + private static final MethodHandle ZVEC_VECTOR_QUERY_SET_FILTER; + private static final MethodHandle ZVEC_VECTOR_QUERY_SET_INCLUDE_VECTOR; + private static final MethodHandle ZVEC_VECTOR_QUERY_SET_OUTPUT_FIELDS; + private static final MethodHandle ZVEC_QUERY_PARAMS_HNSW_CREATE; + private static final MethodHandle ZVEC_QUERY_PARAMS_HNSW_DESTROY; + private static final MethodHandle ZVEC_VECTOR_QUERY_SET_HNSW_PARAMS; + + static { + FfmNativeLoader.load(); + LOOKUP = SymbolLookup.loaderLookup(); + ZVEC_GET_VERSION = downcall("zvec_get_version", FunctionDescriptor.of(ADDRESS)); + ZVEC_IS_INITIALIZED = downcall("zvec_is_initialized", FunctionDescriptor.of(JAVA_BOOLEAN)); + ZVEC_INITIALIZE = downcall("zvec_initialize", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + ZVEC_SHUTDOWN = downcall("zvec_shutdown", FunctionDescriptor.of(JAVA_INT)); + ZVEC_GET_LAST_ERROR = downcall("zvec_get_last_error", 
FunctionDescriptor.of(JAVA_INT, ADDRESS)); + ZVEC_FREE = downcall("zvec_free", FunctionDescriptor.ofVoid(ADDRESS)); + ZVEC_COLLECTION_SCHEMA_CREATE = + downcall("zvec_collection_schema_create", FunctionDescriptor.of(ADDRESS, ADDRESS)); + ZVEC_COLLECTION_SCHEMA_DESTROY = + downcall("zvec_collection_schema_destroy", FunctionDescriptor.ofVoid(ADDRESS)); + ZVEC_COLLECTION_SCHEMA_ADD_FIELD = + downcall("zvec_collection_schema_add_field", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + ZVEC_COLLECTION_SCHEMA_GET_NAME = + downcall("zvec_collection_schema_get_name", FunctionDescriptor.of(ADDRESS, ADDRESS)); + ZVEC_COLLECTION_SCHEMA_GET_FORWARD_FIELDS = + downcall( + "zvec_collection_schema_get_forward_fields", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS)); + ZVEC_COLLECTION_SCHEMA_GET_VECTOR_FIELDS = + downcall( + "zvec_collection_schema_get_vector_fields", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS)); + ZVEC_FIELD_SCHEMA_CREATE = + downcall( + "zvec_field_schema_create", + FunctionDescriptor.of(ADDRESS, ADDRESS, JAVA_INT, JAVA_BOOLEAN, JAVA_INT)); + ZVEC_FIELD_SCHEMA_DESTROY = downcall("zvec_field_schema_destroy", FunctionDescriptor.ofVoid(ADDRESS)); + ZVEC_FIELD_SCHEMA_GET_NAME = + downcall("zvec_field_schema_get_name", FunctionDescriptor.of(ADDRESS, ADDRESS)); + ZVEC_FIELD_SCHEMA_GET_DATA_TYPE = + downcall("zvec_field_schema_get_data_type", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + ZVEC_FIELD_SCHEMA_IS_NULLABLE = + downcall("zvec_field_schema_is_nullable", FunctionDescriptor.of(JAVA_BOOLEAN, ADDRESS)); + ZVEC_FIELD_SCHEMA_GET_DIMENSION = + downcall("zvec_field_schema_get_dimension", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + ZVEC_FIELD_SCHEMA_GET_INDEX_PARAMS = + downcall("zvec_field_schema_get_index_params", FunctionDescriptor.of(ADDRESS, ADDRESS)); + ZVEC_FIELD_SCHEMA_SET_INDEX_PARAMS = + downcall("zvec_field_schema_set_index_params", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + ZVEC_INDEX_PARAMS_CREATE = + 
downcall("zvec_index_params_create", FunctionDescriptor.of(ADDRESS, JAVA_INT)); + ZVEC_INDEX_PARAMS_DESTROY = downcall("zvec_index_params_destroy", FunctionDescriptor.ofVoid(ADDRESS)); + ZVEC_INDEX_PARAMS_GET_TYPE = + downcall("zvec_index_params_get_type", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + ZVEC_INDEX_PARAMS_SET_METRIC_TYPE = + downcall("zvec_index_params_set_metric_type", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT)); + ZVEC_INDEX_PARAMS_SET_HNSW_PARAMS = + downcall("zvec_index_params_set_hnsw_params", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, JAVA_INT)); + ZVEC_INDEX_PARAMS_GET_HNSW_M = + downcall("zvec_index_params_get_hnsw_m", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + ZVEC_INDEX_PARAMS_GET_HNSW_EF_CONSTRUCTION = + downcall("zvec_index_params_get_hnsw_ef_construction", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + ZVEC_COLLECTION_CREATE_AND_OPEN = + downcall( + "zvec_collection_create_and_open", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS)); + ZVEC_COLLECTION_OPEN = + downcall("zvec_collection_open", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS)); + ZVEC_COLLECTION_CLOSE = downcall("zvec_collection_close", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + ZVEC_COLLECTION_FLUSH = downcall("zvec_collection_flush", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + ZVEC_COLLECTION_GET_SCHEMA = + downcall("zvec_collection_get_schema", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + ZVEC_COLLECTION_INSERT = + downcall( + "zvec_collection_insert", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS)); + ZVEC_COLLECTION_QUERY = + downcall( + "zvec_collection_query", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS)); + ZVEC_DOC_CREATE = downcall("zvec_doc_create", FunctionDescriptor.of(ADDRESS)); + ZVEC_DOC_DESTROY = downcall("zvec_doc_destroy", FunctionDescriptor.ofVoid(ADDRESS)); + ZVEC_DOCS_FREE = downcall("zvec_docs_free", FunctionDescriptor.ofVoid(ADDRESS, 
JAVA_LONG)); + ZVEC_DOC_SET_PK = downcall("zvec_doc_set_pk", FunctionDescriptor.ofVoid(ADDRESS, ADDRESS)); + ZVEC_DOC_SET_FIELD_NULL = + downcall("zvec_doc_set_field_null", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + ZVEC_DOC_ADD_FIELD_BY_VALUE = + downcall( + "zvec_doc_add_field_by_value", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, ADDRESS, JAVA_LONG)); + ZVEC_DOC_GET_PK_COPY = downcall("zvec_doc_get_pk_copy", FunctionDescriptor.of(ADDRESS, ADDRESS)); + ZVEC_DOC_GET_SCORE = downcall("zvec_doc_get_score", FunctionDescriptor.of(JAVA_FLOAT, ADDRESS)); + ZVEC_DOC_GET_FIELD_NAMES = + downcall("zvec_doc_get_field_names", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS)); + ZVEC_DOC_IS_FIELD_NULL = + downcall("zvec_doc_is_field_null", FunctionDescriptor.of(JAVA_BOOLEAN, ADDRESS, ADDRESS)); + ZVEC_DOC_GET_FIELD_VALUE_COPY = + downcall( + "zvec_doc_get_field_value_copy", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT, ADDRESS, ADDRESS)); + ZVEC_FREE_STR_ARRAY = downcall("zvec_free_str_array", FunctionDescriptor.ofVoid(ADDRESS, JAVA_LONG)); + ZVEC_VECTOR_QUERY_CREATE = downcall("zvec_vector_query_create", FunctionDescriptor.of(ADDRESS)); + ZVEC_VECTOR_QUERY_DESTROY = + downcall("zvec_vector_query_destroy", FunctionDescriptor.ofVoid(ADDRESS)); + ZVEC_VECTOR_QUERY_SET_TOPK = + downcall("zvec_vector_query_set_topk", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT)); + ZVEC_VECTOR_QUERY_SET_FIELD_NAME = + downcall("zvec_vector_query_set_field_name", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + ZVEC_VECTOR_QUERY_SET_QUERY_VECTOR = + downcall( + "zvec_vector_query_set_query_vector", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_LONG)); + ZVEC_VECTOR_QUERY_SET_FILTER = + downcall("zvec_vector_query_set_filter", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + ZVEC_VECTOR_QUERY_SET_INCLUDE_VECTOR = + downcall( + "zvec_vector_query_set_include_vector", + FunctionDescriptor.of(JAVA_INT, ADDRESS, 
JAVA_BOOLEAN)); + ZVEC_VECTOR_QUERY_SET_OUTPUT_FIELDS = + downcall( + "zvec_vector_query_set_output_fields", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_LONG)); + ZVEC_QUERY_PARAMS_HNSW_CREATE = + downcall( + "zvec_query_params_hnsw_create", + FunctionDescriptor.of(ADDRESS, JAVA_INT, JAVA_FLOAT, JAVA_BOOLEAN, JAVA_BOOLEAN)); + ZVEC_QUERY_PARAMS_HNSW_DESTROY = + downcall("zvec_query_params_hnsw_destroy", FunctionDescriptor.ofVoid(ADDRESS)); + ZVEC_VECTOR_QUERY_SET_HNSW_PARAMS = + downcall("zvec_vector_query_set_hnsw_params", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + } + + private FfmNative() {} + + public static String version() { + try { + MemorySegment versionPtr = (MemorySegment) ZVEC_GET_VERSION.invokeExact(); + return readUtf8CString(versionPtr, VERSION_STRING_MAX_BYTES); + } catch (Throwable t) { + throw new IllegalStateException("Failed to read zvec version", t); + } + } + + public static void ensureInitialized() { + try { + if ((boolean) ZVEC_IS_INITIALIZED.invokeExact()) { + return; + } + + synchronized (INIT_LOCK) { + if ((boolean) ZVEC_IS_INITIALIZED.invokeExact()) { + return; + } + + check((int) ZVEC_INITIALIZE.invokeExact(MemorySegment.NULL), "zvec_initialize"); + if (SHUTDOWN_HOOK_REGISTERED.compareAndSet(false, true)) { + Runtime.getRuntime().addShutdownHook(new Thread(FfmNative::shutdownQuietly)); + } + } + } catch (Throwable t) { + if (t instanceof RuntimeException runtimeException) { + throw runtimeException; + } + throw new IllegalStateException("Failed to initialize zvec", t); + } + } + + static void check(int code, String operation) { + if (code == ZVEC_OK) { + return; + } + throw lastError(operation, code); + } + + static ZvecException lastError(String operation, int fallbackCode) { + try (Arena arena = Arena.ofConfined()) { + MemorySegment outMessage = arena.allocate(ADDRESS); + int status = (int) ZVEC_GET_LAST_ERROR.invokeExact(outMessage); + MemorySegment messagePtr = outMessage.get(ADDRESS, 0); + try { + if (status 
!= ZVEC_OK || messagePtr.equals(MemorySegment.NULL)) { + return new ZvecException(fallbackCode, operation + " failed"); + } + return new ZvecException(fallbackCode, operation + ": " + readUtf8CString(messagePtr)); + } finally { + free(messagePtr); + } + } catch (Throwable t) { + return new ZvecException(fallbackCode, operation + " failed"); + } + } + + static MethodHandle handleCollectionCreateAndOpen() { + return ZVEC_COLLECTION_CREATE_AND_OPEN; + } + + static MethodHandle handleCollectionOpen() { + return ZVEC_COLLECTION_OPEN; + } + + static MethodHandle handleCollectionClose() { + return ZVEC_COLLECTION_CLOSE; + } + + static MethodHandle handleCollectionFlush() { + return ZVEC_COLLECTION_FLUSH; + } + + static MethodHandle handleCollectionGetSchema() { + return ZVEC_COLLECTION_GET_SCHEMA; + } + + static MethodHandle handleCollectionInsert() { + return ZVEC_COLLECTION_INSERT; + } + + static MethodHandle handleCollectionQuery() { + return ZVEC_COLLECTION_QUERY; + } + + static MethodHandle handleDocCreate() { + return ZVEC_DOC_CREATE; + } + + static MethodHandle handleDocDestroy() { + return ZVEC_DOC_DESTROY; + } + + static MethodHandle handleDocsFree() { + return ZVEC_DOCS_FREE; + } + + static MethodHandle handleDocSetPk() { + return ZVEC_DOC_SET_PK; + } + + static MethodHandle handleDocSetFieldNull() { + return ZVEC_DOC_SET_FIELD_NULL; + } + + static MethodHandle handleDocAddFieldByValue() { + return ZVEC_DOC_ADD_FIELD_BY_VALUE; + } + + static MethodHandle handleDocGetPkCopy() { + return ZVEC_DOC_GET_PK_COPY; + } + + static MethodHandle handleDocGetScore() { + return ZVEC_DOC_GET_SCORE; + } + + static MethodHandle handleDocGetFieldNames() { + return ZVEC_DOC_GET_FIELD_NAMES; + } + + static MethodHandle handleDocIsFieldNull() { + return ZVEC_DOC_IS_FIELD_NULL; + } + + static MethodHandle handleDocGetFieldValueCopy() { + return ZVEC_DOC_GET_FIELD_VALUE_COPY; + } + + static MethodHandle handleFreeStrArray() { + return ZVEC_FREE_STR_ARRAY; + } + + static MethodHandle 
handleVectorQueryCreate() { + return ZVEC_VECTOR_QUERY_CREATE; + } + + static MethodHandle handleVectorQueryDestroy() { + return ZVEC_VECTOR_QUERY_DESTROY; + } + + static MethodHandle handleVectorQuerySetTopK() { + return ZVEC_VECTOR_QUERY_SET_TOPK; + } + + static MethodHandle handleVectorQuerySetFieldName() { + return ZVEC_VECTOR_QUERY_SET_FIELD_NAME; + } + + static MethodHandle handleVectorQuerySetQueryVector() { + return ZVEC_VECTOR_QUERY_SET_QUERY_VECTOR; + } + + static MethodHandle handleVectorQuerySetFilter() { + return ZVEC_VECTOR_QUERY_SET_FILTER; + } + + static MethodHandle handleVectorQuerySetIncludeVector() { + return ZVEC_VECTOR_QUERY_SET_INCLUDE_VECTOR; + } + + static MethodHandle handleVectorQuerySetOutputFields() { + return ZVEC_VECTOR_QUERY_SET_OUTPUT_FIELDS; + } + + static MethodHandle handleQueryParamsHnswCreate() { + return ZVEC_QUERY_PARAMS_HNSW_CREATE; + } + + static MethodHandle handleQueryParamsHnswDestroy() { + return ZVEC_QUERY_PARAMS_HNSW_DESTROY; + } + + static MethodHandle handleVectorQuerySetHnswParams() { + return ZVEC_VECTOR_QUERY_SET_HNSW_PARAMS; + } + + static MethodHandle handleCollectionSchemaCreate() { + return ZVEC_COLLECTION_SCHEMA_CREATE; + } + + static MethodHandle handleCollectionSchemaDestroy() { + return ZVEC_COLLECTION_SCHEMA_DESTROY; + } + + static MethodHandle handleCollectionSchemaAddField() { + return ZVEC_COLLECTION_SCHEMA_ADD_FIELD; + } + + static MethodHandle handleCollectionSchemaGetName() { + return ZVEC_COLLECTION_SCHEMA_GET_NAME; + } + + static MethodHandle handleCollectionSchemaGetForwardFields() { + return ZVEC_COLLECTION_SCHEMA_GET_FORWARD_FIELDS; + } + + static MethodHandle handleCollectionSchemaGetVectorFields() { + return ZVEC_COLLECTION_SCHEMA_GET_VECTOR_FIELDS; + } + + static MethodHandle handleFieldSchemaCreate() { + return ZVEC_FIELD_SCHEMA_CREATE; + } + + static MethodHandle handleFieldSchemaDestroy() { + return ZVEC_FIELD_SCHEMA_DESTROY; + } + + static MethodHandle handleFieldSchemaGetName() { + 
return ZVEC_FIELD_SCHEMA_GET_NAME; + } + + static MethodHandle handleFieldSchemaGetDataType() { + return ZVEC_FIELD_SCHEMA_GET_DATA_TYPE; + } + + static MethodHandle handleFieldSchemaIsNullable() { + return ZVEC_FIELD_SCHEMA_IS_NULLABLE; + } + + static MethodHandle handleFieldSchemaGetDimension() { + return ZVEC_FIELD_SCHEMA_GET_DIMENSION; + } + + static MethodHandle handleFieldSchemaGetIndexParams() { + return ZVEC_FIELD_SCHEMA_GET_INDEX_PARAMS; + } + + static MethodHandle handleFieldSchemaSetIndexParams() { + return ZVEC_FIELD_SCHEMA_SET_INDEX_PARAMS; + } + + static MethodHandle handleIndexParamsCreate() { + return ZVEC_INDEX_PARAMS_CREATE; + } + + static MethodHandle handleIndexParamsDestroy() { + return ZVEC_INDEX_PARAMS_DESTROY; + } + + static MethodHandle handleIndexParamsGetType() { + return ZVEC_INDEX_PARAMS_GET_TYPE; + } + + static MethodHandle handleIndexParamsSetMetricType() { + return ZVEC_INDEX_PARAMS_SET_METRIC_TYPE; + } + + static MethodHandle handleIndexParamsSetHnswParams() { + return ZVEC_INDEX_PARAMS_SET_HNSW_PARAMS; + } + + static MethodHandle handleIndexParamsGetHnswM() { + return ZVEC_INDEX_PARAMS_GET_HNSW_M; + } + + static MethodHandle handleIndexParamsGetHnswEfConstruction() { + return ZVEC_INDEX_PARAMS_GET_HNSW_EF_CONSTRUCTION; + } + + static void free(MemorySegment segment) { + if (segment == null || segment.equals(MemorySegment.NULL)) { + return; + } + + try { + ZVEC_FREE.invokeExact(segment); + } catch (Throwable t) { + if (t instanceof RuntimeException runtimeException) { + throw runtimeException; + } + throw new IllegalStateException("Failed to free native memory", t); + } + } + + static String readUtf8CString(MemorySegment cString) { + return readUtf8CString(cString, Long.MAX_VALUE); + } + + private static String readUtf8CString(MemorySegment cString, long maxBytes) { + return cString.reinterpret(maxBytes).getString(0); + } + + private static MethodHandle downcall(String name, FunctionDescriptor descriptor) { + return 
LINKER.downcallHandle(LOOKUP.findOrThrow(name), descriptor); + } + + private static void shutdownQuietly() { + try { + if ((boolean) ZVEC_IS_INITIALIZED.invokeExact()) { + int ignoredStatus = (int) ZVEC_SHUTDOWN.invokeExact(); + } + } catch (Throwable ignored) { + } + } +} diff --git a/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmNativeBackend.java b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmNativeBackend.java new file mode 100644 index 000000000..4234aed29 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmNativeBackend.java @@ -0,0 +1,67 @@ +package org.zvec.internal.ffm; + +import java.util.List; +import org.zvec.CollectionSchema; +import org.zvec.Doc; +import org.zvec.VectorQuery; +import org.zvec.internal.NativeBackend; +import org.zvec.internal.NativeHandle; +import org.zvec.internal.NativeOpenResult; + +public final class FfmNativeBackend implements NativeBackend { + @Override + public String id() { + return "ffm"; + } + + @Override + public String version() { + return FfmNative.version(); + } + + @Override + public void ensureInitialized() { + FfmNative.ensureInitialized(); + } + + @Override + public NativeOpenResult open(String path) { + ensureInitialized(); + return FfmCollections.open(path); + } + + @Override + public NativeOpenResult createAndOpen(String path, CollectionSchema schema) { + ensureInitialized(); + return FfmCollections.createAndOpen(path, schema); + } + + @Override + public void close(NativeHandle handle) { + FfmCollections.close(handle); + } + + @Override + public void flush(NativeHandle handle) { + FfmCollections.flush(handle); + } + + @Override + public CollectionSchema readSchema(NativeHandle handle) { + return FfmCollections.readSchema(handle); + } + + @Override + public int insert(NativeHandle handle, CollectionSchema schema, List docs) { + return FfmCollections.insert(handle, schema, docs); + } + + @Override + public List query( + NativeHandle handle, + CollectionSchema 
/**
 * Locates the zvec C API shared library bundled under {@code /META-INF/native/<platform>/},
 * copies it into a fresh temporary directory and loads it with {@link System#load(String)}.
 *
 * <p>Supported platform ids are {@code darwin-*}, {@code linux-*} (both for {@code aarch64} and
 * {@code x86_64}) and {@code windows-x86_64}; anything else is rejected with
 * {@link IllegalStateException}.
 */
public final class FfmNativeLoader {
  /** Set once the library has been loaded; also used as the monitor guarding the load. */
  private static final AtomicBoolean LOADED = new AtomicBoolean(false);

  private FfmNativeLoader() {}

  /**
   * Builds the classpath location of the bundled library for the given platform.
   *
   * @param osName operating system name, e.g. the value of the {@code os.name} property
   * @param arch CPU architecture, e.g. the value of the {@code os.arch} property
   * @return {@code /META-INF/native/<platformId>/<library file name>}
   * @throws IllegalStateException if the OS/architecture combination is not supported
   */
  static String platformResourcePath(String osName, String arch) {
    String platformId = platformId(osName, arch);
    return "/META-INF/native/" + platformId + "/" + cApiLibraryName(platformId);
  }

  /**
   * Returns the platform-specific file name of the C API library.
   *
   * @param platformId id starting with {@code windows-}, {@code darwin-} or {@code linux-}
   * @throws IllegalStateException for any other prefix
   */
  static String cApiLibraryName(String platformId) {
    if (platformId.startsWith("windows-")) {
      return "zvec_c_api.dll";
    }
    if (platformId.startsWith("darwin-")) {
      return "libzvec_c_api.dylib";
    }
    if (platformId.startsWith("linux-")) {
      return "libzvec_c_api.so";
    }
    throw new IllegalStateException("Unsupported zvec-java platform: " + platformId);
  }

  /** Maps an OS name and architecture to a {@code <os>-<arch>} platform id. */
  private static String platformId(String osName, String arch) {
    String normalizedOs = osName.toLowerCase(Locale.ROOT);
    // The architecture is validated first, so an unknown arch is reported even for an unknown OS.
    String normalizedArch = normalizeArch(arch);
    if (normalizedOs.contains("mac") || normalizedOs.contains("darwin")) {
      return "darwin-" + normalizedArch;
    }
    if (normalizedOs.contains("linux")) {
      return "linux-" + normalizedArch;
    }
    // Windows is only supported on x86_64; other Windows architectures fall through to the error.
    if (normalizedOs.contains("windows") && normalizedArch.equals("x86_64")) {
      return "windows-x86_64";
    }
    throw new IllegalStateException("Unsupported zvec-java platform: " + osName + " / " + arch);
  }

  /** Collapses architecture aliases to {@code aarch64} or {@code x86_64}. */
  private static String normalizeArch(String arch) {
    String normalizedArch = arch.toLowerCase(Locale.ROOT);
    if (normalizedArch.equals("aarch64") || normalizedArch.equals("arm64")) {
      return "aarch64";
    }
    if (normalizedArch.equals("x86_64") || normalizedArch.equals("amd64")) {
      return "x86_64";
    }
    throw new IllegalStateException("Unsupported zvec-java platform: " + arch);
  }

  /**
   * Extracts and loads the native library the first time it is called; later calls return
   * immediately.
   *
   * @throws IllegalStateException if the platform is unsupported, the resource is missing, or
   *     extraction fails
   */
  public static void load() {
    if (LOADED.get()) {
      return;
    }
    synchronized (LOADED) {
      // Re-check under the monitor: another thread may have finished loading while we waited.
      if (LOADED.get()) {
        return;
      }
      String resource =
          platformResourcePath(System.getProperty("os.name"), System.getProperty("os.arch"));
      Path extracted = extract(resource);
      System.load(extracted.toAbsolutePath().toString());
      LOADED.set(true);
    }
  }

  /**
   * Copies the classpath resource into a new temporary directory.
   *
   * <p>The resource is opened before the directory is created, so a missing resource no longer
   * leaves an empty directory behind.
   */
  private static Path extract(String resourcePath) {
    try (InputStream in = FfmNativeLoader.class.getResourceAsStream(resourcePath)) {
      if (in == null) {
        throw new IllegalStateException("Missing native resource: " + resourcePath);
      }
      Path targetFile = extractionTarget(resourcePath);
      // Registered before copying so that a partially written file is cleaned up as well. The
      // parent directory was registered earlier (in extractionTarget); deleteOnExit processes
      // entries in reverse registration order, so the file goes first and the directory is empty
      // by the time it is deleted.
      targetFile.toFile().deleteOnExit();
      Files.copy(in, targetFile);
      return targetFile;
    } catch (IOException e) {
      throw new IllegalStateException("Failed to extract native library: " + resourcePath, e);
    }
  }

  /**
   * Creates a fresh temporary directory named {@code zvec-java-<platformId>-*} and returns the
   * path inside it where the library file should be written. The file itself is not created.
   *
   * <p>The directory is registered for deletion on JVM exit here, ahead of any file placed in it.
   *
   * @param resourcePath resource path whose last two segments are the platform id and file name
   * @throws IllegalStateException if the path has fewer than two segments or the directory cannot
   *     be created
   */
  static Path extractionTarget(String resourcePath) {
    String fileName = resourcePath.substring(resourcePath.lastIndexOf('/') + 1);
    String[] segments = resourcePath.split("/");
    if (segments.length < 2) {
      throw new IllegalStateException("Unexpected native resource path: " + resourcePath);
    }

    String platformId = segments[segments.length - 2];
    try {
      Path targetDir =
          Files.createTempDirectory(
              Path.of(System.getProperty("java.io.tmpdir")), "zvec-java-" + platformId + "-");
      targetDir.toFile().deleteOnExit();
      return targetDir.resolve(fileName);
    } catch (IOException e) {
      throw new IllegalStateException("Failed to create temp directory for native library", e);
    }
  }
}
targetFile.getParent().toFile().deleteOnExit(); + return targetFile; + } catch (IOException e) { + throw new IllegalStateException("Failed to extract native library: " + resourcePath, e); + } + } + + static Path extractionTarget(String resourcePath) { + String fileName = resourcePath.substring(resourcePath.lastIndexOf('/') + 1); + String[] segments = resourcePath.split("/"); + if (segments.length < 2) { + throw new IllegalStateException("Unexpected native resource path: " + resourcePath); + } + + String platformId = segments[segments.length - 2]; + try { + Path targetDir = + Files.createTempDirectory( + Path.of(System.getProperty("java.io.tmpdir")), "zvec-java-" + platformId + "-"); + return targetDir.resolve(fileName); + } catch (IOException e) { + throw new IllegalStateException("Failed to create temp directory for native library", e); + } + } +} diff --git a/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmQueries.java b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmQueries.java new file mode 100644 index 000000000..4482787b4 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/main/java/org/zvec/internal/ffm/FfmQueries.java @@ -0,0 +1,169 @@ +package org.zvec.internal.ffm; + +import org.zvec.internal.ZvecException; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.List; +import java.util.Objects; +import org.zvec.HnswQueryParams; +import org.zvec.VectorSchema; +import org.zvec.VectorQuery; +import org.zvec.internal.HnswDefaults; + +public final class FfmQueries { + private FfmQueries() {} + + public static MemorySegment toNative( + VectorSchema runtimeSchema, VectorSchema publicSchema, VectorQuery query) { + Objects.requireNonNull(runtimeSchema, "runtimeSchema"); + Objects.requireNonNull(publicSchema, "publicSchema"); + Objects.requireNonNull(query, "query"); + + 
MemorySegment nativeQuery = MemorySegment.NULL; + try (Arena arena = Arena.ofConfined()) { + nativeQuery = (MemorySegment) FfmNative.handleVectorQueryCreate().invokeExact(); + if (nativeQuery.equals(MemorySegment.NULL)) { + throw FfmNative.lastError("zvec_vector_query_create", -1); + } + + FfmNative.check( + (int) + FfmNative.handleVectorQuerySetFieldName() + .invokeExact(nativeQuery, arena.allocateFrom(query.fieldName())), + "zvec_vector_query_set_field_name"); + + float[] vector = query.queryVector(); + MemorySegment vectorBuffer = arena.allocateFrom(JAVA_FLOAT, vector); + FfmNative.check( + (int) + FfmNative.handleVectorQuerySetQueryVector() + .invokeExact(nativeQuery, vectorBuffer, (long) vector.length * Float.BYTES), + "zvec_vector_query_set_query_vector"); + + FfmNative.check( + (int) FfmNative.handleVectorQuerySetTopK().invokeExact(nativeQuery, query.topK()), + "zvec_vector_query_set_topk"); + + FfmNative.check( + (int) + FfmNative.handleVectorQuerySetIncludeVector() + .invokeExact(nativeQuery, query.includeVector()), + "zvec_vector_query_set_include_vector"); + + if (query.filter() != null) { + FfmNative.check( + (int) + FfmNative.handleVectorQuerySetFilter() + .invokeExact(nativeQuery, arena.allocateFrom(query.filter())), + "zvec_vector_query_set_filter"); + } + + if (query.outputFieldsSpecified()) { + List outputFields = query.outputFields(); + if (outputFields.isEmpty()) { + throw new UnsupportedOperationException( + "The current native C API cannot represent an explicit empty output field projection"); + } + MemorySegment fieldsBuffer = arena.allocate(ADDRESS, outputFields.size()); + for (int i = 0; i < outputFields.size(); i++) { + fieldsBuffer.setAtIndex(ADDRESS, i, arena.allocateFrom(outputFields.get(i))); + } + FfmNative.check( + (int) + FfmNative.handleVectorQuerySetOutputFields() + .invokeExact(nativeQuery, fieldsBuffer, (long) outputFields.size()), + "zvec_vector_query_set_output_fields"); + } + + attachHnswParams(nativeQuery, runtimeSchema, 
publicSchema, query); + + return nativeQuery; + } catch (Throwable t) { + destroyQuietly(nativeQuery); + throw propagate("Failed to convert vector query to native", t); + } + } + + public static void destroy(MemorySegment query) { + if (query == null || query.equals(MemorySegment.NULL)) { + return; + } + try { + FfmNative.handleVectorQueryDestroy().invokeExact(query); + } catch (Throwable t) { + throw propagate("Failed to destroy native query", t); + } + } + + private static void destroyQuietly(MemorySegment query) { + if (query == null || query.equals(MemorySegment.NULL)) { + return; + } + try { + FfmNative.handleVectorQueryDestroy().invokeExact(query); + } catch (Throwable ignored) { + } + } + + private static RuntimeException propagate(String message, Throwable cause) { + if (cause instanceof RuntimeException runtimeException) { + return runtimeException; + } + return new IllegalStateException(message, cause); + } + + private static void attachHnswParams( + MemorySegment nativeQuery, + VectorSchema runtimeSchema, + VectorSchema publicSchema, + VectorQuery query) + throws Throwable { + HnswQueryParams params = resolveAttachedHnswParams(runtimeSchema, publicSchema, query); + if (params == null) { + return; + } + + MemorySegment nativeParams = + (MemorySegment) + FfmNative.handleQueryParamsHnswCreate() + .invokeExact(params.ef(), params.radius(), params.linear(), params.usingRefiner()); + if (nativeParams.equals(MemorySegment.NULL)) { + throw FfmNative.lastError("zvec_query_params_hnsw_create", -1); + } + + boolean handedOff = false; + try { + FfmNative.check( + (int) FfmNative.handleVectorQuerySetHnswParams().invokeExact(nativeQuery, nativeParams), + "zvec_vector_query_set_hnsw_params"); + handedOff = true; + } finally { + if (!handedOff) { + FfmNative.handleQueryParamsHnswDestroy().invokeExact(nativeParams); + } + } + } + + static boolean shouldAttachHnswParams(VectorSchema schema, VectorQuery query) { + Objects.requireNonNull(schema, "schema"); + 
/**
 * Converts {@link CollectionSchema} objects to and from native schema handles through the method
 * handles exposed by {@code FfmNative}.
 *
 * <p>All {@code invokeExact} call sites are signature-polymorphic, so the argument and cast types
 * in those expressions are part of the call contract and must not be altered.
 *
 * <p>NOTE(review): the {@code List} declarations below are raw exactly as the reviewed text shows
 * them; the element types (presumably {@code FieldSchema} / {@code VectorSchema}) look like they
 * were lost in extraction -- confirm against the repository.
 */
public final class FfmSchemas {
  private FfmSchemas() {}

  /**
   * Builds a native collection schema holding every scalar field and every vector of
   * {@code schema}, scalar fields first.
   *
   * <p>If adding any field fails, the partially built native schema is destroyed before the
   * failure propagates.
   *
   * @return the native schema handle; release it with {@link #destroy(MemorySegment)}
   * @throws NullPointerException if {@code schema} is null
   * @throws IllegalStateException if a non-runtime throwable occurs during conversion
   */
  public static MemorySegment toNative(CollectionSchema schema) {
    Objects.requireNonNull(schema, "schema");

    // The arena only backs temporary C strings; it is closed when this method returns.
    try (Arena arena = Arena.ofConfined()) {
      MemorySegment schemaHandle =
          (MemorySegment)
              FfmNative.handleCollectionSchemaCreate().invokeExact(arena.allocateFrom(schema.name()));
      if (schemaHandle.equals(MemorySegment.NULL)) {
        throw FfmNative.lastError("zvec_collection_schema_create", -1);
      }

      boolean success = false;
      try {
        for (FieldSchema field : schema.fields()) {
          addField(schemaHandle, fieldHandle(field, arena));
        }
        for (VectorSchema vector : schema.vectors()) {
          addField(schemaHandle, vectorHandle(vector, arena));
        }
        success = true;
        return schemaHandle;
      } finally {
        // Only reached with success == false when something above threw.
        if (!success) {
          destroy(schemaHandle);
        }
      }
    } catch (Throwable t) {
      throw propagate("Failed to convert collection schema to native", t);
    }
  }

  /**
   * Reads the name, scalar fields and vector fields of a native schema into a new
   * {@link CollectionSchema}. The native handle itself is not destroyed here.
   *
   * @throws IllegalStateException if a non-runtime throwable occurs while reading
   */
  public static CollectionSchema fromNative(MemorySegment schemaHandle) {
    try (Arena arena = Arena.ofConfined()) {
      MemorySegment namePtr =
          (MemorySegment) FfmNative.handleCollectionSchemaGetName().invokeExact(schemaHandle);
      String name = FfmNative.readUtf8CString(namePtr);
      return new CollectionSchema(name, readScalarFields(schemaHandle, arena), readVectorFields(schemaHandle, arena));
    } catch (Throwable t) {
      throw propagate("Failed to convert native schema", t);
    }
  }

  /**
   * Destroys a native schema handle. A null or NULL segment is ignored.
   *
   * @throws IllegalStateException if the native call fails with a non-runtime throwable
   */
  public static void destroy(MemorySegment schemaHandle) {
    if (schemaHandle == null || schemaHandle.equals(MemorySegment.NULL)) {
      return;
    }

    try {
      FfmNative.handleCollectionSchemaDestroy().invokeExact(schemaHandle);
    } catch (Throwable t) {
      throw propagate("Failed to destroy native schema", t);
    }
  }

  /**
   * Adds {@code fieldHandle} to the schema and then destroys the field handle, whether or not the
   * add succeeded.
   * NOTE(review): destroying after a successful add implies the native schema keeps its own
   * copy of the field -- confirm in the C API.
   */
  private static void addField(MemorySegment schemaHandle, MemorySegment fieldHandle) throws Throwable {
    try {
      FfmNative.check(
          (int) FfmNative.handleCollectionSchemaAddField().invokeExact(schemaHandle, fieldHandle),
          "zvec_collection_schema_add_field");
    } finally {
      FfmNative.handleFieldSchemaDestroy().invokeExact(fieldHandle);
    }
  }

  /** Creates a native field for a scalar field; the dimension argument is always 0. */
  private static MemorySegment fieldHandle(FieldSchema field, Arena arena) throws Throwable {
    return createFieldHandle(field.name(), field.dataType().code(), field.nullable(), 0, arena);
  }

  /**
   * Creates a native field for a vector (always non-nullable) and applies the HNSW index
   * parameters resolved by {@code HnswDefaults.resolveIndexParams}. The field handle is
   * destroyed if applying the parameters fails.
   */
  private static MemorySegment vectorHandle(VectorSchema vector, Arena arena) throws Throwable {
    MemorySegment fieldHandle =
        createFieldHandle(vector.name(), vector.dataType().code(), false, vector.dimension(), arena);
    boolean success = false;
    try {
      applyHnswIndexParams(fieldHandle, HnswDefaults.resolveIndexParams(vector));
      success = true;
      return fieldHandle;
    } finally {
      if (!success) {
        FfmNative.handleFieldSchemaDestroy().invokeExact(fieldHandle);
      }
    }
  }

  /** Calls the native field constructor and fails if it returns NULL. */
  private static MemorySegment createFieldHandle(
      String name, int dataTypeCode, boolean nullable, int dimension, Arena arena) throws Throwable {
    MemorySegment fieldHandle =
        (MemorySegment)
            FfmNative.handleFieldSchemaCreate()
                .invokeExact(arena.allocateFrom(name), dataTypeCode, nullable, dimension);
    if (fieldHandle.equals(MemorySegment.NULL)) {
      throw FfmNative.lastError("zvec_field_schema_create", -1);
    }
    return fieldHandle;
  }

  /**
   * Attaches HNSW index parameters to a field; does nothing when {@code params} is null.
   *
   * <p>The metric type is always {@code FfmNative.ZVEC_METRIC_TYPE_L2}. The temporary native
   * index-params object is destroyed in every case, including after a successful attach.
   */
  private static void applyHnswIndexParams(MemorySegment fieldHandle, HnswIndexParams params) throws Throwable {
    if (params == null) {
      return;
    }

    MemorySegment indexParams =
        (MemorySegment)
            FfmNative.handleIndexParamsCreate().invokeExact(FfmNative.ZVEC_INDEX_TYPE_HNSW);
    if (indexParams.equals(MemorySegment.NULL)) {
      throw FfmNative.lastError("zvec_index_params_create", -1);
    }

    try {
      FfmNative.check(
          (int)
              FfmNative.handleIndexParamsSetMetricType()
                  .invokeExact(indexParams, FfmNative.ZVEC_METRIC_TYPE_L2),
          "zvec_index_params_set_metric_type");
      FfmNative.check(
          (int)
              FfmNative.handleIndexParamsSetHnswParams()
                  .invokeExact(indexParams, params.m(), params.efConstruction()),
          "zvec_index_params_set_hnsw_params");
      FfmNative.check(
          (int) FfmNative.handleFieldSchemaSetIndexParams().invokeExact(fieldHandle, indexParams),
          "zvec_field_schema_set_index_params");
    } finally {
      FfmNative.handleIndexParamsDestroy().invokeExact(indexParams);
    }
  }

  /** Fetches the native array of forward (scalar) fields via two out-parameters and decodes it. */
  private static List readScalarFields(MemorySegment schemaHandle, Arena arena) throws Throwable {
    MemorySegment outFields = arena.allocate(ADDRESS);
    MemorySegment outCount = arena.allocate(JAVA_LONG);
    FfmNative.check(
        (int)
            FfmNative.handleCollectionSchemaGetForwardFields()
                .invokeExact(schemaHandle, outFields, outCount),
        "zvec_collection_schema_get_forward_fields");
    return readScalarFieldList(outFields.get(ADDRESS, 0), outCount.get(JAVA_LONG, 0));
  }

  /** Fetches the native array of vector fields via two out-parameters and decodes it. */
  private static List readVectorFields(MemorySegment schemaHandle, Arena arena) throws Throwable {
    MemorySegment outFields = arena.allocate(ADDRESS);
    MemorySegment outCount = arena.allocate(JAVA_LONG);
    FfmNative.check(
        (int)
            FfmNative.handleCollectionSchemaGetVectorFields()
                .invokeExact(schemaHandle, outFields, outCount),
        "zvec_collection_schema_get_vector_fields");
    return readVectorFieldList(outFields.get(ADDRESS, 0), outCount.get(JAVA_LONG, 0));
  }

  /**
   * Decodes {@code count} field pointers starting at {@code arrayPtr} into field schemas.
   *
   * <p>{@code arrayPtr} is always passed to {@code FfmNative.free} before returning, including
   * on the empty/NULL path and on failure. The individual field pointers are not freed here.
   */
  private static List readScalarFieldList(MemorySegment arrayPtr, long count) throws Throwable {
    try {
      if (count == 0 || arrayPtr.equals(MemorySegment.NULL)) {
        return List.of();
      }

      // The pointer arrives as a zero-length segment; give it the size of count addresses.
      MemorySegment array = arrayPtr.reinterpret(count * ADDRESS.byteSize());
      List fields = new ArrayList<>(Math.toIntExact(count));
      for (long index = 0; index < count; index++) {
        MemorySegment fieldPtr = array.getAtIndex(ADDRESS, index);
        String name =
            FfmNative.readUtf8CString(
                (MemorySegment) FfmNative.handleFieldSchemaGetName().invokeExact(fieldPtr));
        int dataTypeCode = (int) FfmNative.handleFieldSchemaGetDataType().invokeExact(fieldPtr);
        boolean nullable = (boolean) FfmNative.handleFieldSchemaIsNullable().invokeExact(fieldPtr);
        fields.add(new FieldSchema(name, fromCode(dataTypeCode), nullable));
      }
      return fields;
    } finally {
      FfmNative.free(arrayPtr);
    }
  }

  /**
   * Decodes {@code count} field pointers starting at {@code arrayPtr} into vector schemas.
   *
   * <p>The last two {@code VectorSchema} constructor arguments are always null.
   * NOTE(review): presumably those are the tuning profile and expected document count, which the
   * native schema does not carry -- confirm against the VectorSchema constructor.
   * {@code arrayPtr} is always passed to {@code FfmNative.free} before returning.
   */
  private static List readVectorFieldList(MemorySegment arrayPtr, long count) throws Throwable {
    try {
      if (count == 0 || arrayPtr.equals(MemorySegment.NULL)) {
        return List.of();
      }

      MemorySegment array = arrayPtr.reinterpret(count * ADDRESS.byteSize());
      List vectors = new ArrayList<>(Math.toIntExact(count));
      for (long index = 0; index < count; index++) {
        MemorySegment fieldPtr = array.getAtIndex(ADDRESS, index);
        String name =
            FfmNative.readUtf8CString(
                (MemorySegment) FfmNative.handleFieldSchemaGetName().invokeExact(fieldPtr));
        int dataTypeCode = (int) FfmNative.handleFieldSchemaGetDataType().invokeExact(fieldPtr);
        int dimension = (int) FfmNative.handleFieldSchemaGetDimension().invokeExact(fieldPtr);
        vectors.add(
            new VectorSchema(
                name,
                fromCode(dataTypeCode),
                dimension,
                readHnswIndexParams(fieldPtr),
                null,
                null));
      }
      return vectors;
    } finally {
      FfmNative.free(arrayPtr);
    }
  }

  /**
   * Reads the HNSW parameters of a native field.
   *
   * @return the parameters, or {@code null} when the field has no index params or its index type
   *     is not {@code FfmNative.ZVEC_INDEX_TYPE_HNSW}
   */
  private static HnswIndexParams readHnswIndexParams(MemorySegment fieldPtr) throws Throwable {
    MemorySegment indexParams =
        (MemorySegment) FfmNative.handleFieldSchemaGetIndexParams().invokeExact(fieldPtr);
    if (indexParams.equals(MemorySegment.NULL)) {
      return null;
    }

    int indexType = (int) FfmNative.handleIndexParamsGetType().invokeExact(indexParams);
    if (indexType != FfmNative.ZVEC_INDEX_TYPE_HNSW) {
      return null;
    }

    int m = (int) FfmNative.handleIndexParamsGetHnswM().invokeExact(indexParams);
    int efConstruction =
        (int) FfmNative.handleIndexParamsGetHnswEfConstruction().invokeExact(indexParams);
    return new HnswIndexParams(m, efConstruction);
  }

  /**
   * Finds the {@link DataType} whose {@code code()} equals {@code code}.
   *
   * @throws IllegalArgumentException if no constant has that code
   */
  private static DataType fromCode(int code) {
    for (DataType value : DataType.values()) {
      if (value.code() == code) {
        return value;
      }
    }
    throw new IllegalArgumentException("Unsupported native data type code: " + code);
  }

  /**
   * Returns {@code cause} unchanged when it is a runtime exception; otherwise wraps it in an
   * {@link IllegalStateException} carrying {@code message}.
   * NOTE(review): this also wraps JVM {@code Error}s -- consider rethrowing those unchanged.
   */
  private static RuntimeException propagate(String message, Throwable cause) {
    if (cause instanceof RuntimeException runtimeException) {
      return runtimeException;
    }
    return new IllegalStateException(message, cause);
  }
}
b/java/zvec-java/zvec-java-ffm/src/main/resources/META-INF/services/org.zvec.internal.NativeBackendProvider @@ -0,0 +1 @@ +org.zvec.internal.ffm.FfmNativeBackendProvider diff --git a/java/zvec-java/zvec-java-jni/pom.xml b/java/zvec-java/zvec-java-jni/pom.xml new file mode 100644 index 000000000..0aea189ed --- /dev/null +++ b/java/zvec-java/zvec-java-jni/pom.xml @@ -0,0 +1,81 @@ + + 4.0.0 + + + org.zvec + zvec-java-parent + 0.0.1-SNAPSHOT + + + zvec-java-jni + zvec-java-jni + + + 11 + + + + + org.zvec + zvec-java-api + ${project.version} + + + org.zvec + zvec-java-api + ${project.version} + test-jar + tests + test + + + org.junit.jupiter + junit-jupiter + test + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + ${maven.compiler.release} + + + + org.codehaus.mojo + exec-maven-plugin + + + build-native-library + process-resources + + exec + + + ${project.basedir}/../../../scripts/build_java_native.sh + + ${project.build.outputDirectory}/META-INF/native + jni + ${zvec.native.platform} + ${project.basedir}/src/main/native + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + false + -Dorg.zvec.backend=jni + + + + + diff --git a/java/zvec-java/zvec-java-jni/src/main/java/org/zvec/internal/jni/JniHandle.java b/java/zvec-java/zvec-java-jni/src/main/java/org/zvec/internal/jni/JniHandle.java new file mode 100644 index 000000000..76b4e8b2a --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/main/java/org/zvec/internal/jni/JniHandle.java @@ -0,0 +1,41 @@ +package org.zvec.internal.jni; + +import java.util.Objects; +import org.zvec.internal.NativeHandle; + +final class JniHandle implements NativeHandle { + private final long address; + + JniHandle(long address) { + if (address == 0L) { + throw new IllegalArgumentException("address must not be 0"); + } + this.address = address; + } + + long address() { + return address; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof JniHandle)) { 
/**
 * Declarations of the native entry points used by the JNI backend.
 *
 * <p>The static initializer runs {@code JniNativeLoader.load()} when this class is initialized,
 * i.e. before any of the native methods below can be invoked.
 *
 * <p>NOTE(review): the methods are presumably implemented in {@code zvec_java_jni.cc}, which this
 * patch adds -- confirm the symbol names there. The {@code List} types are raw exactly as the
 * reviewed text shows them; the element type (presumably {@code Doc}) looks like it was lost in
 * extraction -- confirm against the repository.
 */
final class JniNative {
  static {
    JniNativeLoader.load();
  }

  private JniNative() {}

  /** Native call returning a version string. */
  static native String version();

  /** Native call that initializes the underlying library (details live in native code). */
  static native void ensureInitialized();

  /**
   * Native call that creates and opens a collection.
   *
   * @return a {@code long} that callers wrap in a {@code JniHandle}
   */
  static native long createAndOpen(String path, CollectionSchema schema);

  /**
   * Native call that opens a collection.
   *
   * @return a {@code long} that callers wrap in a {@code JniHandle}
   */
  static native long open(String path);

  /** Native call that closes the collection identified by {@code handle}. */
  static native void close(long handle);

  /** Native call that flushes the collection identified by {@code handle}. */
  static native void flush(long handle);

  /** Native call that reads the schema of the collection identified by {@code handle}. */
  static native CollectionSchema readSchema(long handle);

  /** Native call that inserts {@code docs}; the meaning of the int result is defined natively. */
  static native int insert(long handle, CollectionSchema schema, List docs);

  /** Native call that runs {@code query} and returns a list of results. */
  static native List query(
      long handle,
      CollectionSchema querySchema,
      CollectionSchema resultSchema,
      VectorQuery query);
}
+import org.zvec.internal.NativeOpenResult; + +public final class JniNativeBackend implements NativeBackend { + @Override + public String id() { + return "jni"; + } + + @Override + public String version() { + return JniNative.version(); + } + + @Override + public void ensureInitialized() { + JniNative.ensureInitialized(); + } + + @Override + public NativeOpenResult open(String path) { + ensureInitialized(); + long address = JniNative.open(path); + JniHandle handle = new JniHandle(address); + try { + return new NativeOpenResult(handle, JniNative.readSchema(address)); + } catch (RuntimeException e) { + close(handle); + throw e; + } + } + + @Override + public NativeOpenResult createAndOpen(String path, CollectionSchema schema) { + ensureInitialized(); + long address = JniNative.createAndOpen(path, schema); + JniHandle handle = new JniHandle(address); + try { + return new NativeOpenResult(handle, JniNative.readSchema(address)); + } catch (RuntimeException e) { + close(handle); + throw e; + } + } + + @Override + public void close(NativeHandle handle) { + JniNative.close(address(handle)); + } + + @Override + public void flush(NativeHandle handle) { + JniNative.flush(address(handle)); + } + + @Override + public CollectionSchema readSchema(NativeHandle handle) { + return JniNative.readSchema(address(handle)); + } + + @Override + public int insert(NativeHandle handle, CollectionSchema schema, List docs) { + return JniNative.insert(address(handle), schema, docs); + } + + @Override + public List query( + NativeHandle handle, + CollectionSchema querySchema, + CollectionSchema resultSchema, + VectorQuery query) { + return JniNative.query(address(handle), querySchema, resultSchema, query); + } + + private static long address(NativeHandle handle) { + if (handle instanceof JniHandle) { + JniHandle jniHandle = (JniHandle) handle; + return jniHandle.address(); + } + throw new IllegalArgumentException("Handle is not a JNI handle: " + handle.getClass().getName()); + } +} diff --git 
/**
 * Extracts the two bundled native libraries (the zvec C API and the JNI glue) for the current
 * platform into a fresh temporary directory and loads them, C API first.
 *
 * <p>Supported platform ids are {@code darwin-*}, {@code linux-*} (both for {@code aarch64} and
 * {@code x86_64}) and {@code windows-x86_64}.
 */
public final class JniNativeLoader {
  private static final AtomicBoolean LOADED = new AtomicBoolean(false);

  private JniNativeLoader() {}

  /**
   * Returns the classpath directory holding the libraries for the given platform.
   *
   * @throws IllegalStateException if the OS/architecture combination is not supported
   */
  static String platformResourceDir(String osName, String arch) {
    String id = platformId(osName, arch);
    return "/META-INF/native/" + id;
  }

  /**
   * Returns the file name of the C API library for {@code platformId}.
   *
   * @throws IllegalStateException if the id has none of the known OS prefixes
   */
  static String cApiLibraryName(String platformId) {
    final boolean windows = platformId.startsWith("windows-");
    final boolean darwin = platformId.startsWith("darwin-");
    final boolean linux = platformId.startsWith("linux-");
    if (!(windows || darwin || linux)) {
      throw new IllegalStateException("Unsupported zvec-java-jni platform: " + platformId);
    }
    if (windows) {
      return "zvec_c_api.dll";
    }
    return darwin ? "libzvec_c_api.dylib" : "libzvec_c_api.so";
  }

  /**
   * Returns the file name of the JNI glue library for {@code platformId}.
   *
   * @throws IllegalStateException if the id has none of the known OS prefixes
   */
  static String jniLibraryName(String platformId) {
    final boolean windows = platformId.startsWith("windows-");
    final boolean darwin = platformId.startsWith("darwin-");
    final boolean linux = platformId.startsWith("linux-");
    if (!(windows || darwin || linux)) {
      throw new IllegalStateException("Unsupported zvec-java-jni platform: " + platformId);
    }
    if (windows) {
      return "zvec_java_jni.dll";
    }
    return darwin ? "libzvec_java_jni.dylib" : "libzvec_java_jni.so";
  }

  /** Maps an OS name and architecture to a {@code <os>-<arch>} platform id. */
  private static String platformId(String osName, String arch) {
    String os = osName.toLowerCase(Locale.ROOT);
    // Validated before the OS is inspected, so a bad architecture is always reported first.
    String cpu = normalizeArch(arch);
    boolean apple = os.contains("mac") || os.contains("darwin");
    if (apple) {
      return "darwin-" + cpu;
    }
    if (os.contains("linux")) {
      return "linux-" + cpu;
    }
    // Windows is accepted on x86_64 only.
    if (os.contains("windows") && cpu.equals("x86_64")) {
      return "windows-x86_64";
    }
    throw new IllegalStateException("Unsupported zvec-java-jni platform: " + osName + " / " + arch);
  }

  /** Collapses architecture aliases to {@code aarch64} or {@code x86_64}. */
  private static String normalizeArch(String arch) {
    switch (arch.toLowerCase(Locale.ROOT)) {
      case "aarch64":
      case "arm64":
        return "aarch64";
      case "x86_64":
      case "amd64":
        return "x86_64";
      default:
        throw new IllegalStateException("Unsupported zvec-java-jni platform: " + arch);
    }
  }

  /**
   * Extracts and loads both native libraries on the first call; later calls return immediately.
   *
   * @throws IllegalStateException if the platform is unsupported, a resource is missing, or
   *     extraction fails
   */
  public static void load() {
    if (LOADED.get()) {
      return;
    }
    synchronized (LOADED) {
      // Second look under the monitor, in case another thread completed the load meanwhile.
      if (!LOADED.get()) {
        String dir =
            platformResourceDir(System.getProperty("os.name"), System.getProperty("os.arch"));
        String id = dir.substring(dir.lastIndexOf('/') + 1);
        Path workDir = extractionTarget(dir);
        Path cApiFile = extract(dir + "/" + cApiLibraryName(id), workDir);
        Path jniFile = extract(dir + "/" + jniLibraryName(id), workDir);
        System.load(cApiFile.toAbsolutePath().toString());
        System.load(jniFile.toAbsolutePath().toString());
        LOADED.set(true);
      }
    }
  }

  /**
   * Creates a temporary directory named {@code zvec-java-<platformId>-*} and registers it for
   * deletion on JVM exit.
   *
   * @param resourceDir resource directory whose last segment is the platform id
   * @throws IllegalStateException if the directory cannot be created
   */
  static Path extractionTarget(String resourceDir) {
    String id = resourceDir.substring(resourceDir.lastIndexOf('/') + 1);
    Path tmpRoot = Path.of(System.getProperty("java.io.tmpdir"));
    Path created;
    try {
      created = Files.createTempDirectory(tmpRoot, "zvec-java-" + id + "-");
    } catch (IOException e) {
      throw new IllegalStateException("Failed to create temp directory for native libraries", e);
    }
    created.toFile().deleteOnExit();
    return created;
  }

  /**
   * Copies one classpath resource into {@code targetDir}, keeping its file name, and registers
   * the copy for deletion on JVM exit.
   */
  private static Path extract(String resourcePath, Path targetDir) {
    String fileName = resourcePath.substring(resourcePath.lastIndexOf('/') + 1);
    Path destination = targetDir.resolve(fileName);

    try (InputStream in = JniNativeLoader.class.getResourceAsStream(resourcePath)) {
      if (in == null) {
        throw new IllegalStateException("Missing native resource: " + resourcePath);
      }
      Files.copy(in, destination);
    } catch (IOException e) {
      throw new IllegalStateException("Failed to extract native library: " + resourcePath, e);
    }
    destination.toFile().deleteOnExit();
    return destination;
  }
}
zvec_index_params_destroy(p); } +}; +struct DocDeleter { + void operator()(zvec_doc_t *p) const { zvec_doc_destroy(p); } +}; +struct QueryDeleter { + void operator()(zvec_vector_query_t *p) const { zvec_vector_query_destroy(p); } +}; + +using SchemaPtr = std::unique_ptr; +using FieldPtr = std::unique_ptr; +using IndexPtr = std::unique_ptr; +using DocPtr = std::unique_ptr; +using QueryPtr = std::unique_ptr; + +std::string to_string(JNIEnv *env, jstring value) { + if (value == nullptr) { + return ""; + } + const char *chars = env->GetStringUTFChars(value, nullptr); + if (chars == nullptr) { + return ""; + } + std::string out(chars); + env->ReleaseStringUTFChars(value, chars); + return out; +} + +jstring to_jstring(JNIEnv *env, const std::string &value) { + return env->NewStringUTF(value.c_str()); +} + +jclass find_class(JNIEnv *env, const char *name) { + jclass cls = env->FindClass(name); + if (cls == nullptr) { + throw std::runtime_error(std::string("Missing Java class: ") + name); + } + return cls; +} + +jmethodID method(JNIEnv *env, jclass cls, const char *name, const char *sig) { + jmethodID mid = env->GetMethodID(cls, name, sig); + if (mid == nullptr) { + throw std::runtime_error(std::string("Missing Java method: ") + name + sig); + } + return mid; +} + +jmethodID static_method(JNIEnv *env, jclass cls, const char *name, + const char *sig) { + jmethodID mid = env->GetStaticMethodID(cls, name, sig); + if (mid == nullptr) { + throw std::runtime_error(std::string("Missing Java static method: ") + name + + sig); + } + return mid; +} + +jfieldID static_field(JNIEnv *env, jclass cls, const char *name, + const char *sig) { + jfieldID fid = env->GetStaticFieldID(cls, name, sig); + if (fid == nullptr) { + throw std::runtime_error(std::string("Missing Java static field: ") + name); + } + return fid; +} + +void throw_exception(JNIEnv *env, const char *class_name, + const std::string &message) { + if (env->ExceptionCheck()) { + return; + } + jclass cls = 
env->FindClass(class_name); + if (cls != nullptr) { + env->ThrowNew(cls, message.c_str()); + } +} + +void throw_zvec(JNIEnv *env, int code, const std::string &message) { + if (env->ExceptionCheck()) { + return; + } + jclass cls = env->FindClass("org/zvec/internal/ZvecException"); + if (cls == nullptr) { + throw_exception(env, "java/lang/IllegalStateException", message); + return; + } + jmethodID ctor = env->GetMethodID(cls, "", "(ILjava/lang/String;)V"); + if (ctor == nullptr) { + throw_exception(env, "java/lang/IllegalStateException", message); + return; + } + jstring msg = env->NewStringUTF(message.c_str()); + jobject ex = env->NewObject(cls, ctor, static_cast(code), msg); + if (ex != nullptr) { + env->Throw(static_cast(ex)); + } +} + +std::string last_error_message(const char *operation) { + char *message = nullptr; + zvec_get_last_error(&message); + std::string out(operation); + if (message != nullptr) { + out += ": "; + out += message; + zvec_free(message); + } else { + out += " failed"; + } + return out; +} + +bool check(JNIEnv *env, zvec_error_code_t code, const char *operation) { + if (code == ZVEC_OK) { + return true; + } + throw_zvec(env, static_cast(code), last_error_message(operation)); + return false; +} + +int data_type_code(JNIEnv *env, jobject data_type) { + jclass cls = find_class(env, "org/zvec/DataType"); + return env->CallIntMethod(data_type, method(env, cls, "code", "()I")); +} + +jobject data_type_for_code(JNIEnv *env, int code) { + jclass cls = find_class(env, "org/zvec/DataType"); + const char *name = nullptr; + switch (code) { + case kTypeString: + name = "STRING"; + break; + case kTypeBool: + name = "BOOL"; + break; + case kTypeInt64: + name = "INT64"; + break; + case kTypeDouble: + name = "DOUBLE"; + break; + case kTypeVectorFp32: + name = "VECTOR_FP32"; + break; + default: + throw std::runtime_error("Unsupported native data type code: " + + std::to_string(code)); + } + return env->GetStaticObjectField( + cls, static_field(env, cls, name, 
"Lorg/zvec/DataType;")); +} + +int list_size(JNIEnv *env, jobject list) { + jclass cls = find_class(env, "java/util/List"); + return env->CallIntMethod(list, method(env, cls, "size", "()I")); +} + +jobject list_get(JNIEnv *env, jobject list, int index) { + jclass cls = find_class(env, "java/util/List"); + return env->CallObjectMethod(list, + method(env, cls, "get", "(I)Ljava/lang/Object;"), + static_cast(index)); +} + +void list_add(JNIEnv *env, jobject list, jobject value) { + jclass cls = find_class(env, "java/util/List"); + env->CallBooleanMethod(list, + method(env, cls, "add", "(Ljava/lang/Object;)Z"), + value); +} + +jobject new_array_list(JNIEnv *env) { + jclass cls = find_class(env, "java/util/ArrayList"); + return env->NewObject(cls, method(env, cls, "", "()V")); +} + +class Iterator { + public: + Iterator(JNIEnv *env, jobject iterable) : env_(env) { + jobject iterator_obj = env_->CallObjectMethod( + iterable, method(env_, find_class(env_, "java/lang/Iterable"), + "iterator", "()Ljava/util/Iterator;")); + iterator_ = iterator_obj; + iterator_class_ = find_class(env_, "java/util/Iterator"); + } + + bool has_next() { + return env_->CallBooleanMethod( + iterator_, method(env_, iterator_class_, "hasNext", "()Z")); + } + + jobject next() { + return env_->CallObjectMethod( + iterator_, method(env_, iterator_class_, "next", "()Ljava/lang/Object;")); + } + + private: + JNIEnv *env_; + jobject iterator_; + jclass iterator_class_; +}; + +jobject map_entry_set(JNIEnv *env, jobject map) { + return env->CallObjectMethod( + map, method(env, find_class(env, "java/util/Map"), "entrySet", + "()Ljava/util/Set;")); +} + +jobject entry_key(JNIEnv *env, jobject entry) { + return env->CallObjectMethod( + entry, method(env, find_class(env, "java/util/Map$Entry"), "getKey", + "()Ljava/lang/Object;")); +} + +jobject entry_value(JNIEnv *env, jobject entry) { + return env->CallObjectMethod( + entry, method(env, find_class(env, "java/util/Map$Entry"), "getValue", + 
"()Ljava/lang/Object;")); +} + +void apply_hnsw_index(JNIEnv *env, zvec_field_schema_t *field, + jobject vector_schema) { + jclass defaults_cls = find_class(env, "org/zvec/internal/HnswDefaults"); + jobject params = env->CallStaticObjectMethod( + defaults_cls, + static_method(env, defaults_cls, "resolveIndexParams", + "(Lorg/zvec/VectorSchema;)Lorg/zvec/HnswIndexParams;"), + vector_schema); + if (env->ExceptionCheck() || params == nullptr) { + return; + } + + jclass params_cls = find_class(env, "org/zvec/HnswIndexParams"); + int m = env->CallIntMethod(params, method(env, params_cls, "m", "()I")); + int ef = env->CallIntMethod( + params, method(env, params_cls, "efConstruction", "()I")); + + IndexPtr index(zvec_index_params_create(kIndexTypeHnsw)); + if (!index) { + throw std::runtime_error(last_error_message("zvec_index_params_create")); + } + if (!check(env, + zvec_index_params_set_metric_type(index.get(), kMetricTypeL2), + "zvec_index_params_set_metric_type")) { + return; + } + if (!check(env, zvec_index_params_set_hnsw_params(index.get(), m, ef), + "zvec_index_params_set_hnsw_params")) { + return; + } + check(env, zvec_field_schema_set_index_params(field, index.get()), + "zvec_field_schema_set_index_params"); +} + +SchemaPtr schema_to_native(JNIEnv *env, jobject schema) { + jclass schema_cls = find_class(env, "org/zvec/CollectionSchema"); + std::string name = + to_string(env, static_cast(env->CallObjectMethod( + schema, method(env, schema_cls, "name", + "()Ljava/lang/String;")))); + SchemaPtr native_schema(zvec_collection_schema_create(name.c_str())); + if (!native_schema) { + throw std::runtime_error(last_error_message("zvec_collection_schema_create")); + } + + jobject fields = env->CallObjectMethod( + schema, method(env, schema_cls, "fields", "()Ljava/util/List;")); + jclass field_cls = find_class(env, "org/zvec/FieldSchema"); + for (int i = 0; i < list_size(env, fields); i++) { + jobject field = list_get(env, fields, i); + std::string field_name = + 
to_string(env, static_cast(env->CallObjectMethod( + field, method(env, field_cls, "name", + "()Ljava/lang/String;")))); + jobject data_type = env->CallObjectMethod( + field, method(env, field_cls, "dataType", "()Lorg/zvec/DataType;")); + jboolean nullable = + env->CallBooleanMethod(field, method(env, field_cls, "nullable", "()Z")); + FieldPtr native_field(zvec_field_schema_create( + field_name.c_str(), data_type_code(env, data_type), nullable, 0)); + if (!native_field) { + throw std::runtime_error(last_error_message("zvec_field_schema_create")); + } + if (!check(env, + zvec_collection_schema_add_field(native_schema.get(), + native_field.get()), + "zvec_collection_schema_add_field")) { + return nullptr; + } + } + + jobject vectors = env->CallObjectMethod( + schema, method(env, schema_cls, "vectors", "()Ljava/util/List;")); + jclass vector_cls = find_class(env, "org/zvec/VectorSchema"); + for (int i = 0; i < list_size(env, vectors); i++) { + jobject vector = list_get(env, vectors, i); + std::string vector_name = + to_string(env, static_cast(env->CallObjectMethod( + vector, method(env, vector_cls, "name", + "()Ljava/lang/String;")))); + jobject data_type = env->CallObjectMethod( + vector, + method(env, vector_cls, "dataType", "()Lorg/zvec/DataType;")); + jint dimension = + env->CallIntMethod(vector, method(env, vector_cls, "dimension", "()I")); + FieldPtr native_field(zvec_field_schema_create( + vector_name.c_str(), data_type_code(env, data_type), false, + static_cast(dimension))); + if (!native_field) { + throw std::runtime_error(last_error_message("zvec_field_schema_create")); + } + apply_hnsw_index(env, native_field.get(), vector); + if (env->ExceptionCheck()) { + return nullptr; + } + if (!check(env, + zvec_collection_schema_add_field(native_schema.get(), + native_field.get()), + "zvec_collection_schema_add_field")) { + return nullptr; + } + } + + return native_schema; +} + +jobject hnsw_index_params_to_java(JNIEnv *env, const zvec_field_schema_t *field) { + const 
zvec_index_params_t *params = zvec_field_schema_get_index_params(field); + if (params == nullptr || zvec_index_params_get_type(params) != kIndexTypeHnsw) { + return nullptr; + } + jclass cls = find_class(env, "org/zvec/HnswIndexParams"); + return env->NewObject(cls, method(env, cls, "", "(II)V"), + zvec_index_params_get_hnsw_m(params), + zvec_index_params_get_hnsw_ef_construction(params)); +} + +jobject schema_from_native(JNIEnv *env, zvec_collection_schema_t *schema) { + jclass field_cls = find_class(env, "org/zvec/FieldSchema"); + jclass vector_cls = find_class(env, "org/zvec/VectorSchema"); + jclass schema_cls = find_class(env, "org/zvec/CollectionSchema"); + + jobject fields = new_array_list(env); + zvec_field_schema_t **field_array = nullptr; + size_t field_count = 0; + if (!check(env, + zvec_collection_schema_get_forward_fields(schema, &field_array, + &field_count), + "zvec_collection_schema_get_forward_fields")) { + return nullptr; + } + for (size_t i = 0; i < field_count; i++) { + zvec_field_schema_t *field = field_array[i]; + jobject data_type = + data_type_for_code(env, zvec_field_schema_get_data_type(field)); + jobject java_field = env->NewObject( + field_cls, + method(env, field_cls, "", + "(Ljava/lang/String;Lorg/zvec/DataType;Z)V"), + to_jstring(env, zvec_field_schema_get_name(field)), data_type, + zvec_field_schema_is_nullable(field)); + list_add(env, fields, java_field); + } + zvec_free(field_array); + + jobject vectors = new_array_list(env); + zvec_field_schema_t **vector_array = nullptr; + size_t vector_count = 0; + if (!check(env, + zvec_collection_schema_get_vector_fields(schema, &vector_array, + &vector_count), + "zvec_collection_schema_get_vector_fields")) { + return nullptr; + } + for (size_t i = 0; i < vector_count; i++) { + zvec_field_schema_t *vector = vector_array[i]; + jobject data_type = + data_type_for_code(env, zvec_field_schema_get_data_type(vector)); + jobject hnsw = hnsw_index_params_to_java(env, vector); + jobject java_vector = 
env->NewObject( + vector_cls, + method(env, vector_cls, "", + "(Ljava/lang/String;Lorg/zvec/DataType;ILorg/zvec/" + "HnswIndexParams;Lorg/zvec/TuningProfile;Ljava/lang/Long;)V"), + to_jstring(env, zvec_field_schema_get_name(vector)), data_type, + static_cast(zvec_field_schema_get_dimension(vector)), hnsw, + nullptr, nullptr); + list_add(env, vectors, java_vector); + } + zvec_free(vector_array); + + return env->NewObject( + schema_cls, + method(env, schema_cls, "", + "(Ljava/lang/String;Ljava/util/List;Ljava/util/List;)V"), + to_jstring(env, zvec_collection_schema_get_name(schema)), fields, vectors); +} + +zvec_collection_t *handle(jlong address) { + if (address == 0) { + throw std::invalid_argument("Native collection handle is 0"); + } + return reinterpret_cast(address); +} + +int field_data_type(JNIEnv *env, jobject schema, const std::string &name) { + jclass schema_cls = find_class(env, "org/zvec/CollectionSchema"); + jobject field = env->CallObjectMethod( + schema, + method(env, schema_cls, "field", + "(Ljava/lang/String;)Lorg/zvec/FieldSchema;"), + to_jstring(env, name)); + if (field != nullptr) { + return data_type_code(env, env->CallObjectMethod( + field, + method(env, find_class(env, "org/zvec/FieldSchema"), + "dataType", "()Lorg/zvec/DataType;"))); + } + jobject vector = env->CallObjectMethod( + schema, + method(env, schema_cls, "vector", + "(Ljava/lang/String;)Lorg/zvec/VectorSchema;"), + to_jstring(env, name)); + if (vector != nullptr) { + return data_type_code(env, env->CallObjectMethod( + vector, + method(env, find_class(env, "org/zvec/VectorSchema"), + "dataType", "()Lorg/zvec/DataType;"))); + } + throw std::invalid_argument("Unknown field: " + name); +} + +void add_doc_field(JNIEnv *env, zvec_doc_t *doc, const std::string &name, + int type, jobject value) { + if (type == kTypeString) { + std::string string_value = to_string(env, static_cast(value)); + check(env, zvec_doc_add_field_by_value(doc, name.c_str(), type, + string_value.data(), + 
string_value.size()), + "zvec_doc_add_field_by_value"); + return; + } + if (type == kTypeBool) { + bool bool_value = env->CallBooleanMethod( + value, method(env, find_class(env, "java/lang/Boolean"), + "booleanValue", "()Z")); + check(env, zvec_doc_add_field_by_value(doc, name.c_str(), type, &bool_value, + sizeof(bool_value)), + "zvec_doc_add_field_by_value"); + return; + } + if (type == kTypeInt64) { + int64_t long_value = env->CallLongMethod( + value, + method(env, find_class(env, "java/lang/Long"), "longValue", "()J")); + check(env, zvec_doc_add_field_by_value(doc, name.c_str(), type, &long_value, + sizeof(long_value)), + "zvec_doc_add_field_by_value"); + return; + } + if (type == kTypeDouble) { + double double_value = env->CallDoubleMethod( + value, method(env, find_class(env, "java/lang/Double"), "doubleValue", + "()D")); + check(env, zvec_doc_add_field_by_value(doc, name.c_str(), type, + &double_value, sizeof(double_value)), + "zvec_doc_add_field_by_value"); + return; + } + throw std::invalid_argument("Unsupported scalar field type: " + + std::to_string(type)); +} + +jobject vector_schema(JNIEnv *env, jobject schema, const std::string &name); +int vector_dimension(JNIEnv *env, jobject schema, const std::string &name); + +DocPtr doc_to_native(JNIEnv *env, jobject doc, jobject schema) { + jclass doc_cls = find_class(env, "org/zvec/Doc"); + DocPtr native_doc(zvec_doc_create()); + if (!native_doc) { + throw std::runtime_error(last_error_message("zvec_doc_create")); + } + + std::string id = + to_string(env, static_cast(env->CallObjectMethod( + doc, method(env, doc_cls, "id", "()Ljava/lang/String;")))); + zvec_doc_set_pk(native_doc.get(), id.c_str()); + + jobject fields = env->CallObjectMethod( + doc, method(env, doc_cls, "fields", "()Ljava/util/Map;")); + for (Iterator it(env, map_entry_set(env, fields)); it.has_next();) { + jobject entry = it.next(); + std::string name = to_string(env, static_cast(entry_key(env, entry))); + add_doc_field(env, native_doc.get(), 
name, field_data_type(env, schema, name), + entry_value(env, entry)); + if (env->ExceptionCheck()) { + return nullptr; + } + } + + jobject null_fields = env->CallObjectMethod( + doc, method(env, doc_cls, "nullFields", "()Ljava/util/Set;")); + for (Iterator it(env, null_fields); it.has_next();) { + std::string name = to_string(env, static_cast(it.next())); + if (!check(env, zvec_doc_set_field_null(native_doc.get(), name.c_str()), + "zvec_doc_set_field_null")) { + return nullptr; + } + } + + jobject vectors = env->CallObjectMethod( + doc, method(env, doc_cls, "vectors", "()Ljava/util/Map;")); + for (Iterator it(env, map_entry_set(env, vectors)); it.has_next();) { + jobject entry = it.next(); + std::string name = to_string(env, static_cast(entry_key(env, entry))); + jfloatArray vector_array = static_cast(entry_value(env, entry)); + jsize length = env->GetArrayLength(vector_array); + int expected_dimension = vector_dimension(env, schema, name); + if (length != expected_dimension) { + throw std::invalid_argument( + "Vector dimension mismatch for field " + name + ": expected " + + std::to_string(expected_dimension) + ", got " + std::to_string(length)); + } + std::vector values(length); + env->GetFloatArrayRegion(vector_array, 0, length, values.data()); + if (!check(env, + zvec_doc_add_field_by_value(native_doc.get(), name.c_str(), + kTypeVectorFp32, values.data(), + values.size() * sizeof(float)), + "zvec_doc_add_field_by_value")) { + return nullptr; + } + } + + return native_doc; +} + +std::vector docs_to_native(JNIEnv *env, jobject docs, jobject schema) { + std::vector out; + int count = list_size(env, docs); + out.reserve(count); + for (int i = 0; i < count; i++) { + out.push_back(doc_to_native(env, list_get(env, docs, i), schema)); + if (env->ExceptionCheck()) { + break; + } + } + return out; +} + +jobject vector_schema(JNIEnv *env, jobject schema, const std::string &name) { + return env->CallObjectMethod( + schema, + method(env, find_class(env, 
"org/zvec/CollectionSchema"), "vector", + "(Ljava/lang/String;)Lorg/zvec/VectorSchema;"), + to_jstring(env, name)); +} + +int vector_dimension(JNIEnv *env, jobject schema, const std::string &name) { + jobject vector = vector_schema(env, schema, name); + if (vector == nullptr) { + throw std::invalid_argument("Unknown vector field: " + name); + } + return env->CallIntMethod( + vector, method(env, find_class(env, "org/zvec/VectorSchema"), + "dimension", "()I")); +} + +void attach_hnsw_query(JNIEnv *env, zvec_vector_query_t *native_query, + jobject public_vector_schema, jobject query) { + jclass defaults_cls = find_class(env, "org/zvec/internal/HnswDefaults"); + jobject params = env->CallStaticObjectMethod( + defaults_cls, + static_method(env, defaults_cls, "resolveQueryParams", + "(Lorg/zvec/VectorSchema;Lorg/zvec/VectorQuery;)" + "Lorg/zvec/HnswQueryParams;"), + public_vector_schema, query); + if (env->ExceptionCheck() || params == nullptr) { + return; + } + jclass params_cls = find_class(env, "org/zvec/HnswQueryParams"); + int ef = env->CallIntMethod(params, method(env, params_cls, "ef", "()I")); + float radius = + env->CallFloatMethod(params, method(env, params_cls, "radius", "()F")); + jboolean linear = + env->CallBooleanMethod(params, method(env, params_cls, "linear", "()Z")); + jboolean refiner = env->CallBooleanMethod( + params, method(env, params_cls, "usingRefiner", "()Z")); + + zvec_hnsw_query_params_t *raw = + zvec_query_params_hnsw_create(ef, radius, linear, refiner); + if (raw == nullptr) { + throw std::runtime_error(last_error_message("zvec_query_params_hnsw_create")); + } + if (check(env, zvec_vector_query_set_hnsw_params(native_query, raw), + "zvec_vector_query_set_hnsw_params")) { + raw = nullptr; + } + if (raw != nullptr) { + zvec_query_params_hnsw_destroy(raw); + } +} + +QueryPtr query_to_native(JNIEnv *env, jobject query, jobject query_schema, + jobject result_schema) { + jclass query_cls = find_class(env, "org/zvec/VectorQuery"); + QueryPtr 
native_query(zvec_vector_query_create()); + if (!native_query) { + throw std::runtime_error(last_error_message("zvec_vector_query_create")); + } + + std::string field_name = + to_string(env, static_cast(env->CallObjectMethod( + query, method(env, query_cls, "fieldName", + "()Ljava/lang/String;")))); + if (!check(env, + zvec_vector_query_set_field_name(native_query.get(), + field_name.c_str()), + "zvec_vector_query_set_field_name")) { + return nullptr; + } + + jfloatArray query_vector = static_cast(env->CallObjectMethod( + query, method(env, query_cls, "queryVector", "()[F"))); + jsize length = env->GetArrayLength(query_vector); + std::vector values(length); + env->GetFloatArrayRegion(query_vector, 0, length, values.data()); + if (!check(env, + zvec_vector_query_set_query_vector(native_query.get(), + values.data(), + values.size() * sizeof(float)), + "zvec_vector_query_set_query_vector")) { + return nullptr; + } + + if (!check(env, + zvec_vector_query_set_topk( + native_query.get(), + env->CallIntMethod(query, method(env, query_cls, "topK", "()I"))), + "zvec_vector_query_set_topk")) { + return nullptr; + } + + if (!check(env, + zvec_vector_query_set_include_vector( + native_query.get(), + env->CallBooleanMethod(query, + method(env, query_cls, "includeVector", + "()Z"))), + "zvec_vector_query_set_include_vector")) { + return nullptr; + } + + jstring filter = static_cast(env->CallObjectMethod( + query, method(env, query_cls, "filter", "()Ljava/lang/String;"))); + std::string filter_value; + if (filter != nullptr) { + filter_value = to_string(env, filter); + if (!check(env, + zvec_vector_query_set_filter(native_query.get(), + filter_value.c_str()), + "zvec_vector_query_set_filter")) { + return nullptr; + } + } + + if (env->CallBooleanMethod( + query, method(env, query_cls, "outputFieldsSpecified", "()Z"))) { + jobject output_fields = env->CallObjectMethod( + query, method(env, query_cls, "outputFields", "()Ljava/util/List;")); + int count = list_size(env, 
output_fields); + if (count == 0) { + throw std::runtime_error( + "The current native C API cannot represent an explicit empty output field projection"); + } + std::vector names; + std::vector ptrs; + names.reserve(count); + ptrs.reserve(count); + for (int i = 0; i < count; i++) { + names.push_back(to_string(env, static_cast(list_get(env, output_fields, i)))); + } + for (const std::string &name : names) { + ptrs.push_back(name.c_str()); + } + if (!check(env, + zvec_vector_query_set_output_fields(native_query.get(), + ptrs.data(), ptrs.size()), + "zvec_vector_query_set_output_fields")) { + return nullptr; + } + } + + jobject runtime_vector = vector_schema(env, query_schema, field_name); + jobject public_vector = vector_schema(env, result_schema, field_name); + if (runtime_vector == nullptr || public_vector == nullptr) { + throw std::invalid_argument("Unknown vector field: " + field_name); + } + attach_hnsw_query(env, native_query.get(), public_vector, query); + return native_query; +} + +jobject doc_from_native(JNIEnv *env, zvec_doc_t *doc, jobject schema) { + jclass doc_cls = find_class(env, "org/zvec/Doc"); + char *pk = const_cast(zvec_doc_get_pk_copy(doc)); + std::string id = pk == nullptr ? 
"" : pk; + zvec_free(pk); + + jobject result = env->CallStaticObjectMethod( + doc_cls, + static_method(env, doc_cls, "result", + "(Ljava/lang/String;D)Lorg/zvec/Doc;"), + to_jstring(env, id), static_cast(zvec_doc_get_score(doc))); + + char **names = nullptr; + size_t count = 0; + if (!check(env, zvec_doc_get_field_names(doc, &names, &count), + "zvec_doc_get_field_names")) { + return nullptr; + } + + for (size_t i = 0; i < count; i++) { + std::string name = names[i]; + int type = field_data_type(env, schema, name); + if (type != kTypeVectorFp32 && zvec_doc_is_field_null(doc, name.c_str())) { + env->CallObjectMethod( + result, method(env, doc_cls, "nullField", + "(Ljava/lang/String;)Lorg/zvec/Doc;"), + to_jstring(env, name)); + continue; + } + + void *value = nullptr; + size_t size = 0; + if (!check(env, + zvec_doc_get_field_value_copy(doc, name.c_str(), type, &value, + &size), + "zvec_doc_get_field_value_copy")) { + zvec_free_str_array(names, count); + return nullptr; + } + if (type == kTypeString) { + std::string text(static_cast(value), size); + env->CallObjectMethod( + result, + method(env, doc_cls, "field", + "(Ljava/lang/String;Ljava/lang/String;)Lorg/zvec/Doc;"), + to_jstring(env, name), to_jstring(env, text)); + } else if (type == kTypeBool) { + env->CallObjectMethod( + result, method(env, doc_cls, "field", + "(Ljava/lang/String;Z)Lorg/zvec/Doc;"), + to_jstring(env, name), *static_cast(value)); + } else if (type == kTypeInt64) { + env->CallObjectMethod( + result, method(env, doc_cls, "field", + "(Ljava/lang/String;J)Lorg/zvec/Doc;"), + to_jstring(env, name), *static_cast(value)); + } else if (type == kTypeDouble) { + env->CallObjectMethod( + result, method(env, doc_cls, "field", + "(Ljava/lang/String;D)Lorg/zvec/Doc;"), + to_jstring(env, name), *static_cast(value)); + } else if (type == kTypeVectorFp32) { + jsize length = static_cast(size / sizeof(float)); + jfloatArray vector = env->NewFloatArray(length); + env->SetFloatArrayRegion(vector, 0, length, + 
static_cast(value)); + env->CallObjectMethod( + result, + method(env, doc_cls, "vector", + "(Ljava/lang/String;[F)Lorg/zvec/Doc;"), + to_jstring(env, name), vector); + } + zvec_free(value); + } + zvec_free_str_array(names, count); + return result; +} + +} // namespace + +extern "C" JNIEXPORT jstring JNICALL +Java_org_zvec_internal_jni_JniNative_version(JNIEnv *env, jclass) { + return to_jstring(env, zvec_get_version()); +} + +extern "C" JNIEXPORT void JNICALL +Java_org_zvec_internal_jni_JniNative_ensureInitialized(JNIEnv *env, jclass) { + if (zvec_is_initialized()) { + return; + } + check(env, zvec_initialize(nullptr), "zvec_initialize"); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_zvec_internal_jni_JniNative_createAndOpen(JNIEnv *env, jclass, + jstring path, + jobject schema) { + try { + SchemaPtr native_schema = schema_to_native(env, schema); + if (env->ExceptionCheck()) { + return 0; + } + zvec_collection_t *collection = nullptr; + std::string path_string = to_string(env, path); + if (!check(env, + zvec_collection_create_and_open(path_string.c_str(), + native_schema.get(), nullptr, + &collection), + "zvec_collection_create_and_open")) { + return 0; + } + return reinterpret_cast(collection); + } catch (const std::invalid_argument &e) { + throw_exception(env, "java/lang/IllegalArgumentException", e.what()); + return 0; + } catch (const std::exception &e) { + throw_exception(env, "java/lang/IllegalStateException", e.what()); + return 0; + } +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_zvec_internal_jni_JniNative_open(JNIEnv *env, jclass, jstring path) { + try { + zvec_collection_t *collection = nullptr; + std::string path_string = to_string(env, path); + if (!check(env, zvec_collection_open(path_string.c_str(), nullptr, + &collection), + "zvec_collection_open")) { + return 0; + } + return reinterpret_cast(collection); + } catch (const std::invalid_argument &e) { + throw_exception(env, "java/lang/IllegalArgumentException", e.what()); + return 0; + } 
catch (const std::exception &e) { + throw_exception(env, "java/lang/IllegalStateException", e.what()); + return 0; + } +} + +extern "C" JNIEXPORT void JNICALL +Java_org_zvec_internal_jni_JniNative_close(JNIEnv *env, jclass, jlong address) { + try { + check(env, zvec_collection_close(handle(address)), "zvec_collection_close"); + } catch (const std::invalid_argument &e) { + throw_exception(env, "java/lang/IllegalArgumentException", e.what()); + } catch (const std::exception &e) { + throw_exception(env, "java/lang/IllegalStateException", e.what()); + } +} + +extern "C" JNIEXPORT void JNICALL +Java_org_zvec_internal_jni_JniNative_flush(JNIEnv *env, jclass, jlong address) { + try { + check(env, zvec_collection_flush(handle(address)), "zvec_collection_flush"); + } catch (const std::invalid_argument &e) { + throw_exception(env, "java/lang/IllegalArgumentException", e.what()); + } catch (const std::exception &e) { + throw_exception(env, "java/lang/IllegalStateException", e.what()); + } +} + +extern "C" JNIEXPORT jobject JNICALL +Java_org_zvec_internal_jni_JniNative_readSchema(JNIEnv *env, jclass, + jlong address) { + try { + zvec_collection_schema_t *schema = nullptr; + if (!check(env, zvec_collection_get_schema(handle(address), &schema), + "zvec_collection_get_schema")) { + return nullptr; + } + SchemaPtr owned(schema); + return schema_from_native(env, schema); + } catch (const std::invalid_argument &e) { + throw_exception(env, "java/lang/IllegalArgumentException", e.what()); + return nullptr; + } catch (const std::exception &e) { + throw_exception(env, "java/lang/IllegalStateException", e.what()); + return nullptr; + } +} + +extern "C" JNIEXPORT jint JNICALL +Java_org_zvec_internal_jni_JniNative_insert(JNIEnv *env, jclass, jlong address, + jobject schema, jobject docs) { + try { + int count = list_size(env, docs); + if (count == 0) { + return 0; + } + std::vector native_docs = docs_to_native(env, docs, schema); + if (env->ExceptionCheck()) { + return 0; + } + std::vector 
doc_ptrs; + doc_ptrs.reserve(native_docs.size()); + for (const auto &doc : native_docs) { + doc_ptrs.push_back(doc.get()); + } + size_t success_count = 0; + size_t error_count = 0; + if (!check(env, + zvec_collection_insert(handle(address), doc_ptrs.data(), + doc_ptrs.size(), &success_count, + &error_count), + "zvec_collection_insert")) { + return 0; + } + if (error_count != 0) { + throw_zvec(env, -1, + "zvec_collection_insert reported " + + std::to_string(error_count) + " per-document failures"); + return 0; + } + return static_cast(success_count); + } catch (const std::invalid_argument &e) { + throw_exception(env, "java/lang/IllegalArgumentException", e.what()); + return 0; + } catch (const std::exception &e) { + throw_exception(env, "java/lang/IllegalStateException", e.what()); + return 0; + } +} + +extern "C" JNIEXPORT jobject JNICALL +Java_org_zvec_internal_jni_JniNative_query(JNIEnv *env, jclass, jlong address, + jobject query_schema, + jobject result_schema, + jobject query) { + try { + QueryPtr native_query = query_to_native(env, query, query_schema, result_schema); + if (env->ExceptionCheck()) { + return nullptr; + } + zvec_doc_t **results = nullptr; + size_t result_count = 0; + if (!check(env, + zvec_collection_query(handle(address), native_query.get(), + &results, &result_count), + "zvec_collection_query")) { + return nullptr; + } + jobject out = new_array_list(env); + for (size_t i = 0; i < result_count; i++) { + list_add(env, out, doc_from_native(env, results[i], result_schema)); + if (env->ExceptionCheck()) { + break; + } + } + zvec_docs_free(results, result_count); + return out; + } catch (const std::invalid_argument &e) { + throw_exception(env, "java/lang/IllegalArgumentException", e.what()); + return nullptr; + } catch (const std::exception &e) { + throw_exception(env, "java/lang/IllegalStateException", e.what()); + return nullptr; + } +} diff --git 
a/java/zvec-java/zvec-java-jni/src/main/resources/META-INF/services/org.zvec.internal.NativeBackendProvider b/java/zvec-java/zvec-java-jni/src/main/resources/META-INF/services/org.zvec.internal.NativeBackendProvider new file mode 100644 index 000000000..3c48b9f13 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/main/resources/META-INF/services/org.zvec.internal.NativeBackendProvider @@ -0,0 +1 @@ +org.zvec.internal.jni.JniNativeBackendProvider diff --git a/src/binding/c/c_api.cc b/src/binding/c/c_api.cc index 2c3489ab9..d10fdc840 100644 --- a/src/binding/c/c_api.cc +++ b/src/binding/c/c_api.cc @@ -1948,8 +1948,7 @@ const char *zvec_collection_schema_get_name( return nullptr; } auto *cpp_schema = reinterpret_cast(schema); - // Use strdup to create a persistent copy since name() returns by value - return strdup(cpp_schema->name().c_str()); + return cpp_schema->name().c_str(); } zvec_error_code_t zvec_collection_schema_set_name(zvec_collection_schema_t *schema, diff --git a/src/include/zvec/c_api.h b/src/include/zvec/c_api.h index c64190d50..26bc690fb 100644 --- a/src/include/zvec/c_api.h +++ b/src/include/zvec/c_api.h @@ -2057,7 +2057,7 @@ zvec_collection_schema_destroy(zvec_collection_schema_t *schema); * * @note Returns a pointer to internal memory. Caller does NOT own the memory * and should NOT free it. The pointer is valid as long as the schema - * exists. + * exists and its name is not modified. 
*/ ZVEC_EXPORT const char *ZVEC_CALL zvec_collection_schema_get_name(const zvec_collection_schema_t *schema); diff --git a/src/include/zvec/db/schema.h b/src/include/zvec/db/schema.h index bce2a1fd4..058694288 100644 --- a/src/include/zvec/db/schema.h +++ b/src/include/zvec/db/schema.h @@ -310,7 +310,7 @@ class CollectionSchema { std::string to_string_formatted(int indent_level = 0) const; - std::string name() const { + const std::string &name() const { return name_; } @@ -398,4 +398,4 @@ class CollectionSchema { uint64_t max_doc_count_per_segment_{MAX_DOC_COUNT_PER_SEGMENT}; }; -} // namespace zvec \ No newline at end of file +} // namespace zvec From a816f1b9143255be6533ca98c2feae24a8972996 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=A3=9E?= Date: Thu, 14 May 2026 11:38:02 -0400 Subject: [PATCH 2/5] test(java): add shared backend coverage --- ...stractCollectionInsertIntegrationTest.java | 84 ++++ ...actCollectionLifecycleIntegrationTest.java | 46 +++ ...bstractCollectionQueryIntegrationTest.java | 364 ++++++++++++++++++ ...currentEncryptedInsertIntegrationTest.java | 60 +++ ...tedCollectionRoundTripIntegrationTest.java | 52 +++ ...yptionMetadataMismatchIntegrationTest.java | 61 +++ .../AbstractKeyRotationIntegrationTest.java | 75 ++++ ...tLargeEncryptedPayloadIntegrationTest.java | 35 ++ ...bstractOpenWithoutKeysIntegrationTest.java | 36 ++ .../java/org/zvec/AbstractQuickStartTest.java | 53 +++ .../zvec/CollectionSchemaEncryptionTest.java | 35 ++ .../test/java/org/zvec/HnswDefaultsTest.java | 74 ++++ .../java/org/zvec/HnswParamsModelTest.java | 72 ++++ .../test/java/org/zvec/SchemaModelTest.java | 43 +++ .../java/org/zvec/ZvecSchemasBuilderTest.java | 73 ++++ .../org/zvec/ZvecSchemasEncryptionTest.java | 69 ++++ .../java/org/zvec/ZvecSearchBuilderTest.java | 59 +++ .../java/org/zvec/crypto/AadEncoderTest.java | 51 +++ .../AbstractCollectionSetActiveKeyIdTest.java | 76 ++++ ...ractZvecCreateAndOpenWithProviderTest.java | 89 +++++ 
.../crypto/AbstractZvecOpenWithKeysTest.java | 42 ++ .../java/org/zvec/crypto/AesGcm256Test.java | 131 +++++++ .../zvec/crypto/DecryptingProjectorTest.java | 106 +++++ .../org/zvec/crypto/EncryptedSchemaTest.java | 81 ++++ .../zvec/crypto/EncryptingInsertorTest.java | 99 +++++ .../zvec/crypto/EncryptionExceptionTest.java | 33 ++ .../zvec/crypto/EncryptionMetadataTest.java | 56 +++ .../org/zvec/crypto/EnvelopeCodecTest.java | 95 +++++ .../EnvelopeRelocationSecurityTest.java | 85 ++++ .../zvec/crypto/FilterFieldScannerTest.java | 51 +++ .../java/org/zvec/crypto/SidecarJsonTest.java | 69 ++++ .../org/zvec/crypto/SidecarMetadataTest.java | 56 +++ .../zvec/crypto/SingletonKeyProviderTest.java | 49 +++ .../org/zvec/internal/NativeBackendsTest.java | 117 ++++++ ...actCollectionConcurrentStressMainTest.java | 66 ++++ .../AbstractCollectionStressMainTest.java | 69 ++++ .../zvec/perf/CollectionStressChecksTest.java | 28 ++ .../zvec/perf/EncryptedFieldBenchmark.java | 53 +++ .../java/org/zvec/perf/LatencyStatsTest.java | 30 ++ .../java/org/zvec/perf/StressOptionsTest.java | 102 +++++ .../zvec/perf/ZvecJavaBindingBenchmark.java | 218 +++++++++++ .../FfmCollectionInsertIntegrationTest.java | 3 + ...FfmCollectionLifecycleIntegrationTest.java | 3 + .../FfmCollectionQueryIntegrationTest.java | 3 + ...currentEncryptedInsertIntegrationTest.java | 3 + ...tedCollectionRoundTripIntegrationTest.java | 3 + ...yptionMetadataMismatchIntegrationTest.java | 3 + .../zvec/FfmKeyRotationIntegrationTest.java | 3 + ...mLargeEncryptedPayloadIntegrationTest.java | 3 + .../FfmOpenWithoutKeysIntegrationTest.java | 3 + .../test/java/org/zvec/FfmQuickStartTest.java | 3 + .../FfmCollectionSetActiveKeyIdTest.java | 3 + .../FfmZvecCreateAndOpenWithProviderTest.java | 3 + .../zvec/crypto/FfmZvecOpenWithKeysTest.java | 3 + .../internal/ffm/FfmNativeLoaderTest.java | 57 +++ .../internal/ffm/FfmNativeVersionTest.java | 12 + .../org/zvec/internal/ffm/FfmQueriesTest.java | 62 +++ 
...FfmCollectionConcurrentStressMainTest.java | 3 + .../perf/FfmCollectionStressMainTest.java | 3 + .../java/org/zvec/perf/FfmDocsBenchmark.java | 83 ++++ .../JniCollectionInsertIntegrationTest.java | 3 + ...JniCollectionLifecycleIntegrationTest.java | 3 + .../JniCollectionQueryIntegrationTest.java | 3 + ...currentEncryptedInsertIntegrationTest.java | 3 + ...tedCollectionRoundTripIntegrationTest.java | 3 + ...yptionMetadataMismatchIntegrationTest.java | 3 + .../zvec/JniKeyRotationIntegrationTest.java | 3 + ...iLargeEncryptedPayloadIntegrationTest.java | 3 + .../JniOpenWithoutKeysIntegrationTest.java | 3 + .../test/java/org/zvec/JniQuickStartTest.java | 3 + .../JniCollectionSetActiveKeyIdTest.java | 3 + .../JniZvecCreateAndOpenWithProviderTest.java | 3 + .../zvec/crypto/JniZvecOpenWithKeysTest.java | 3 + .../internal/jni/JniNativeLoaderTest.java | 59 +++ .../internal/jni/JniNativeVersionTest.java | 12 + ...JniCollectionConcurrentStressMainTest.java | 3 + .../perf/JniCollectionStressMainTest.java | 3 + 77 files changed, 3518 insertions(+) create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionInsertIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionLifecycleIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionQueryIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractConcurrentEncryptedInsertIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractEncryptedCollectionRoundTripIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractEncryptionMetadataMismatchIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractKeyRotationIntegrationTest.java create mode 100644 
java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractLargeEncryptedPayloadIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractOpenWithoutKeysIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractQuickStartTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/CollectionSchemaEncryptionTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/HnswDefaultsTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/HnswParamsModelTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/SchemaModelTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSchemasBuilderTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSchemasEncryptionTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSearchBuilderTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AadEncoderTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractCollectionSetActiveKeyIdTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractZvecCreateAndOpenWithProviderTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractZvecOpenWithKeysTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AesGcm256Test.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/DecryptingProjectorTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptedSchemaTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptingInsertorTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptionExceptionTest.java create mode 100644 
java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptionMetadataTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EnvelopeCodecTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EnvelopeRelocationSecurityTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/FilterFieldScannerTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SidecarJsonTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SidecarMetadataTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SingletonKeyProviderTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/internal/NativeBackendsTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/AbstractCollectionConcurrentStressMainTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/AbstractCollectionStressMainTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/CollectionStressChecksTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/EncryptedFieldBenchmark.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/LatencyStatsTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/StressOptionsTest.java create mode 100644 java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/ZvecJavaBindingBenchmark.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionInsertIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionLifecycleIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionQueryIntegrationTest.java create mode 100644 
java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmConcurrentEncryptedInsertIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmEncryptedCollectionRoundTripIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmEncryptionMetadataMismatchIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmKeyRotationIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmLargeEncryptedPayloadIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmOpenWithoutKeysIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmQuickStartTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmCollectionSetActiveKeyIdTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmZvecCreateAndOpenWithProviderTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmZvecOpenWithKeysTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmNativeLoaderTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmNativeVersionTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmQueriesTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmCollectionConcurrentStressMainTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmCollectionStressMainTest.java create mode 100644 java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmDocsBenchmark.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionInsertIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionLifecycleIntegrationTest.java create mode 
100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionQueryIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniConcurrentEncryptedInsertIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniEncryptedCollectionRoundTripIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniEncryptionMetadataMismatchIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniKeyRotationIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniLargeEncryptedPayloadIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniOpenWithoutKeysIntegrationTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniQuickStartTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniCollectionSetActiveKeyIdTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniZvecCreateAndOpenWithProviderTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniZvecOpenWithKeysTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/internal/jni/JniNativeLoaderTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/internal/jni/JniNativeVersionTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/perf/JniCollectionConcurrentStressMainTest.java create mode 100644 java/zvec-java/zvec-java-jni/src/test/java/org/zvec/perf/JniCollectionStressMainTest.java diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionInsertIntegrationTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionInsertIntegrationTest.java new file mode 100644 index 000000000..c74f1da10 --- /dev/null +++ 
b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionInsertIntegrationTest.java @@ -0,0 +1,84 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.file.Path; +import java.util.List; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public abstract class AbstractCollectionInsertIntegrationTest { + @TempDir Path tempDir; + + @Test + void insertsDocumentsWithAllScalarTypesAndNullableField() { + assumeSupportedPlatform(); + + CollectionSchema schema = + new CollectionSchema( + "docs", + List.of( + new FieldSchema("title", DataType.STRING, false), + new FieldSchema("published", DataType.BOOL, false), + new FieldSchema("views", DataType.INT64, false), + new FieldSchema("rating", DataType.DOUBLE, false), + new FieldSchema("subtitle", DataType.STRING, true)), + List.of(new VectorSchema("embedding", DataType.VECTOR_FP32, 4))); + + try (Collection collection = Zvec.createAndOpen(tempDir.resolve("docs").toString(), schema)) { + int inserted = + collection.insert( + List.of( + Doc.of("doc_1") + .field("title", "alpha") + .field("published", true) + .field("views", 11L) + .field("rating", 4.5d) + .nullField("subtitle") + .vector("embedding", new float[] {1f, 0f, 0f, 0f}), + Doc.of("doc_2") + .field("title", "beta") + .field("published", false) + .field("views", 22L) + .field("rating", 3.25d) + .field("subtitle", "second") + .vector("embedding", new float[] {0f, 1f, 0f, 0f}))); + + assertEquals(2, inserted); + } + } + + @Test + void rejectsVectorDimensionMismatch() { + assumeSupportedPlatform(); + + CollectionSchema schema = + new CollectionSchema( + "docs", + List.of(new FieldSchema("title", DataType.STRING, false)), + List.of(new VectorSchema("embedding", DataType.VECTOR_FP32, 4))); + + try (Collection collection = 
Zvec.createAndOpen(tempDir.resolve("docs_bad").toString(), schema)) { + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, + () -> + collection.insert( + List.of( + Doc.of("doc_bad") + .field("title", "bad") + .vector("embedding", new float[] {1f, 0f})))); + assertTrue(ex.getMessage().contains("Vector dimension mismatch")); + } + } + + private static void assumeSupportedPlatform() { + String osName = System.getProperty("os.name", "").toLowerCase(); + String osArch = System.getProperty("os.arch", "").toLowerCase(); + Assumptions.assumeTrue(osName.contains("mac")); + Assumptions.assumeTrue(osArch.equals("aarch64") || osArch.equals("arm64")); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionLifecycleIntegrationTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionLifecycleIntegrationTest.java new file mode 100644 index 000000000..d451f192f --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionLifecycleIntegrationTest.java @@ -0,0 +1,46 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.file.Path; +import java.util.List; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public abstract class AbstractCollectionLifecycleIntegrationTest { + @TempDir Path tempDir; + + @Test + void createsAndReopensCollection() { + String osName = System.getProperty("os.name", "").toLowerCase(); + String osArch = System.getProperty("os.arch", "").toLowerCase(); + Assumptions.assumeTrue(osName.contains("mac")); + Assumptions.assumeTrue(osArch.equals("aarch64") || osArch.equals("arm64")); + + CollectionSchema schema = + new CollectionSchema( + "docs", + List.of( + new FieldSchema("title", DataType.STRING, false), + 
new FieldSchema("summary", DataType.STRING, true)), + List.of(new VectorSchema("embedding", DataType.VECTOR_FP32, 4))); + + Path collectionPath = tempDir.resolve("docs"); + try (Collection created = Zvec.createAndOpen(collectionPath.toString(), schema)) { + assertNotNull(created.schema()); + assertEquals("docs", created.schema().name()); + assertTrue(created.schema().field("summary").nullable()); + created.flush(); + } + + try (Collection reopened = Zvec.open(collectionPath.toString())) { + assertEquals("docs", reopened.schema().name()); + assertEquals(DataType.STRING, reopened.schema().field("title").dataType()); + assertTrue(reopened.schema().field("summary").nullable()); + assertEquals(4, reopened.schema().vector("embedding").dimension()); + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionQueryIntegrationTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionQueryIntegrationTest.java new file mode 100644 index 000000000..6a5f50417 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractCollectionQueryIntegrationTest.java @@ -0,0 +1,364 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.file.Path; +import java.nio.file.Files; +import java.util.List; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.zvec.internal.HnswDefaults; + +public abstract class AbstractCollectionQueryIntegrationTest { + @TempDir Path tempDir; + + @Test + void queriesInsertedVectors() { + 
assumeSupportedPlatform(); + + CollectionSchema schema = + new CollectionSchema( + "docs", + List.of(new FieldSchema("title", DataType.STRING, false)), + List.of(new VectorSchema("embedding", DataType.VECTOR_FP32, 4))); + + try (Collection collection = Zvec.createAndOpen(tempDir.resolve("docs").toString(), schema)) { + collection.insert( + List.of( + Doc.of("doc_1").field("title", "alpha").vector("embedding", new float[] {1f, 0f, 0f, 0f}), + Doc.of("doc_2").field("title", "beta").vector("embedding", new float[] {0f, 1f, 0f, 0f}))); + + List results = + collection.query( + VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}).topK(2).outputFields("title")); + + assertFalse(results.isEmpty()); + assertEquals("doc_1", results.get(0).id()); + assertEquals("alpha", results.get(0).fields().get("title")); + assertNotNull(results.get(0).score()); + + List vectorResults = + collection.query( + VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}) + .topK(1) + .includeVector(true)); + + assertFalse(vectorResults.isEmpty()); + assertArrayEquals( + new float[] {1f, 0f, 0f, 0f}, vectorResults.get(0).vectors().get("embedding")); + } + } + + @Test + void queriesInsertedVectorsWithFluentSearchBuilder() { + assumeSupportedPlatform(); + + CollectionSchema schema = + new CollectionSchema( + "docs_fluent", + List.of(new FieldSchema("title", DataType.STRING, false)), + List.of(new VectorSchema("embedding", DataType.VECTOR_FP32, 4))); + + try (Collection collection = Zvec.createAndOpen(tempDir.resolve("docs_fluent").toString(), schema)) { + collection.insert( + List.of( + Doc.of("doc_1").field("title", "alpha").vector("embedding", new float[] {1f, 0f, 0f, 0f}), + Doc.of("doc_2").field("title", "beta").vector("embedding", new float[] {0f, 1f, 0f, 0f}))); + + List results = + collection.query( + ZvecSearch.vector("embedding", new float[] {1f, 0f, 0f, 0f}) + .topK(2) + .balanced() + .project("title") + .build()); + + assertFalse(results.isEmpty()); + assertEquals("doc_1", 
results.get(0).id()); + assertEquals("alpha", results.get(0).fields().get("title")); + assertNotNull(results.get(0).score()); + } + } + + @Test + void queriesInsertedVectorsWithExplicitHnswParams() { + assumeSupportedPlatform(); + + CollectionSchema schema = + new CollectionSchema( + "docs_hnsw", + List.of(new FieldSchema("title", DataType.STRING, false)), + List.of( + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withHnswIndex(new HnswIndexParams(16, 200)))); + + try (Collection collection = + Zvec.createAndOpen(tempDir.resolve("docs_hnsw").toString(), schema)) { + collection.insert( + List.of( + Doc.of("doc_1").field("title", "alpha").vector("embedding", new float[] {1f, 0f, 0f, 0f}), + Doc.of("doc_2").field("title", "beta").vector("embedding", new float[] {0f, 1f, 0f, 0f}))); + + List results = + collection.query( + VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}) + .topK(2) + .outputFields("title") + .hnsw(new HnswQueryParams(64, 0.0f, false, false))); + + assertFalse(results.isEmpty()); + assertEquals("doc_1", results.get(0).id()); + assertEquals("alpha", results.get(0).fields().get("title")); + } + } + + @Test + void createAndOpenPreservesVectorTuningHintsForSchemaLevelDefaults() { + assumeSupportedPlatform(); + + CollectionSchema schema = + new CollectionSchema( + "docs_tuning", + List.of(), + List.of( + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withTuningProfile(TuningProfile.ACCURATE, 1_000_000L))); + + try (Collection collection = + Zvec.createAndOpen(tempDir.resolve("docs_tuning").toString(), schema)) { + VectorSchema vector = collection.schema().vector("embedding"); + + assertEquals(TuningProfile.ACCURATE, vector.tuningProfile()); + assertEquals(Long.valueOf(1_000_000L), vector.expectedDocCount()); + assertEquals( + new HnswQueryParams(128, 0.0f, false, false), + HnswDefaults.resolveQueryParams( + vector, VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}))); + } + } + + @Test + void 
openReadsBackEffectiveRawHnswIndexParams() { + assumeSupportedPlatform(); + + Path path = tempDir.resolve("docs_reopen"); + CollectionSchema schema = + new CollectionSchema( + "docs_reopen", + List.of(), + List.of( + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withHnswIndex(new HnswIndexParams(18, 220)))); + + try (Collection ignored = Zvec.createAndOpen(path.toString(), schema)) { + } + + try (Collection reopened = Zvec.open(path.toString())) { + assertEquals( + new HnswIndexParams(18, 220), + reopened.schema().vector("embedding").hnswIndexParams()); + } + } + + @Test + void openPreservesSchemaLevelQueryDefaultsViaEffectiveIndexParams() { + assumeSupportedPlatform(); + + Path path = tempDir.resolve("docs_reopen_defaults"); + CollectionSchema schema = + new CollectionSchema( + "docs_reopen_defaults", + List.of(), + List.of( + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withTuningProfile(TuningProfile.ACCURATE, 1_000_000L))); + + try (Collection ignored = Zvec.createAndOpen(path.toString(), schema)) { + } + + try (Collection reopened = Zvec.open(path.toString())) { + VectorSchema vector = reopened.schema().vector("embedding"); + + assertNull(vector.hnswIndexParams()); + assertEquals(TuningProfile.ACCURATE, vector.tuningProfile()); + assertEquals(Long.valueOf(1_000_000L), vector.expectedDocCount()); + assertEquals( + new HnswQueryParams(128, 0.0f, false, false), + HnswDefaults.resolveQueryParams( + vector, + VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}))); + } + } + + @Test + void rejectsUnknownVectorFieldWithClearError() { + assumeSupportedPlatform(); + + CollectionSchema schema = + new CollectionSchema( + "docs_invalid_query", + List.of(new FieldSchema("title", DataType.STRING, false)), + List.of(new VectorSchema("embedding", DataType.VECTOR_FP32, 4))); + + try (Collection collection = + Zvec.createAndOpen(tempDir.resolve("docs_invalid_query").toString(), schema)) { + IllegalArgumentException ex = + assertThrows( + 
IllegalArgumentException.class, + () -> collection.query(VectorQuery.of("missing", new float[] {1f, 0f, 0f, 0f}))); + assertTrue(ex.getMessage().contains("Unknown vector field")); + } + } + + @Test + void openFallsBackToNativeSchemaWhenJavaMetadataIsMalformed() throws Exception { + assumeSupportedPlatform(); + + Path path = tempDir.resolve("docs_bad_metadata"); + CollectionSchema schema = + new CollectionSchema( + "docs_bad_metadata", + List.of(), + List.of( + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withHnswIndex(new HnswIndexParams(18, 220)))); + + try (Collection ignored = Zvec.createAndOpen(path.toString(), schema)) { + } + + Files.writeString( + path.resolve(".zvec-java-schema.properties"), + "version=1\nvector.embedding.tuningProfile=NOT_A_PROFILE\n"); + + try (Collection reopened = Zvec.open(path.toString())) { + assertEquals( + new HnswIndexParams(18, 220), + reopened.schema().vector("embedding").hnswIndexParams()); + } + } + + @Test + void openPrefersNativeRawHnswParamsOverJavaMetadata() throws Exception { + assumeSupportedPlatform(); + + Path path = tempDir.resolve("docs_stale_metadata"); + CollectionSchema schema = + new CollectionSchema( + "docs_stale_metadata", + List.of(), + List.of( + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withHnswIndex(new HnswIndexParams(18, 220)))); + + try (Collection ignored = Zvec.createAndOpen(path.toString(), schema)) { + } + + Files.writeString( + path.resolve(".zvec-java-schema.properties"), + "version=1\nvector.embedding.hnsw.m=32\nvector.embedding.hnsw.efConstruction=400\n"); + + try (Collection reopened = Zvec.open(path.toString())) { + assertEquals( + new HnswIndexParams(18, 220), + reopened.schema().vector("embedding").hnswIndexParams()); + } + } + + @Test + void openPreservesRawParamsWhenOnlyExpectedDocCountMetadataExists() { + assumeSupportedPlatform(); + + Path path = tempDir.resolve("docs_raw_with_doccount"); + CollectionSchema schema = + new CollectionSchema( + 
"docs_raw_with_doccount", + List.of(), + List.of( + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withHnswIndex(new HnswIndexParams(18, 220)) + .withExpectedDocCount(1_000_000L))); + + try (Collection ignored = Zvec.createAndOpen(path.toString(), schema)) { + } + + try (Collection reopened = Zvec.open(path.toString())) { + VectorSchema vector = reopened.schema().vector("embedding"); + + assertEquals(new HnswIndexParams(18, 220), vector.hnswIndexParams()); + assertNull(vector.tuningProfile()); + assertEquals(Long.valueOf(1_000_000L), vector.expectedDocCount()); + } + } + + @Test + void openPreservesExpectedDocCountOnlyHintState() { + assumeSupportedPlatform(); + + Path path = tempDir.resolve("docs_doccount_only"); + CollectionSchema schema = + new CollectionSchema( + "docs_doccount_only", + List.of(), + List.of( + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withExpectedDocCount(1_000_000L))); + + try (Collection ignored = Zvec.createAndOpen(path.toString(), schema)) { + } + + try (Collection reopened = Zvec.open(path.toString())) { + VectorSchema vector = reopened.schema().vector("embedding"); + + assertNull(vector.hnswIndexParams()); + assertNull(vector.tuningProfile()); + assertEquals(Long.valueOf(1_000_000L), vector.expectedDocCount()); + assertEquals( + new HnswQueryParams(96, 0.0f, false, false), + HnswDefaults.resolveQueryParams( + vector, VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}))); + } + } + + @Test + void openFallsBackToNativeSchemaWhenMetadataOmitsRawStateFlag() throws Exception { + assumeSupportedPlatform(); + + Path path = tempDir.resolve("docs_missing_raw_flag"); + CollectionSchema schema = + new CollectionSchema( + "docs_missing_raw_flag", + List.of(), + List.of( + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withExpectedDocCount(1_000_000L))); + + try (Collection ignored = Zvec.createAndOpen(path.toString(), schema)) { + } + + Files.writeString( + path.resolve(".zvec-java-schema.properties"), + 
"version=1\nvector.embedding.expectedDocCount=1000000\n"); + + try (Collection reopened = Zvec.open(path.toString())) { + VectorSchema vector = reopened.schema().vector("embedding"); + + assertEquals(new HnswIndexParams(24, 300), vector.hnswIndexParams()); + assertNull(vector.tuningProfile()); + assertNull(vector.expectedDocCount()); + } + } + + private static void assumeSupportedPlatform() { + String osName = System.getProperty("os.name", "").toLowerCase(); + String osArch = System.getProperty("os.arch", "").toLowerCase(); + Assumptions.assumeTrue(osName.contains("mac")); + Assumptions.assumeTrue(osArch.equals("aarch64") || osArch.equals("arm64")); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractConcurrentEncryptedInsertIntegrationTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractConcurrentEncryptedInsertIntegrationTest.java new file mode 100644 index 000000000..929414a72 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractConcurrentEncryptedInsertIntegrationTest.java @@ -0,0 +1,60 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.zvec.crypto.KeyProvider; + +public abstract class AbstractConcurrentEncryptedInsertIntegrationTest { + + @Test + void eightThreadsThousandEachAllDecryptCorrectly(@TempDir Path tmp) throws Exception { + Path path = tmp.resolve("docs"); + CollectionSchema schema = ZvecSchemas.collection("docs") + .string("body").encrypted("k1") + .vector("e", 4).build(); + KeyProvider provider = kid -> new byte[32]; + + int threads = 8; + int perThread = 1000; + try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) { + 
ExecutorService pool = Executors.newFixedThreadPool(threads); + List> futures = new ArrayList<>(); + for (int t = 0; t < threads; t++) { + final int tid = t; + futures.add(pool.submit(() -> { + List batch = new ArrayList<>(perThread); + for (int i = 0; i < perThread; i++) { + String id = "t" + tid + "-i" + i; + batch.add(Doc.of(id) + .field("body", "secret-" + id) + .vector("e", new float[] {1f, 0f, 0f, 0f})); + } + return col.insert(batch); + })); + } + for (Future f : futures) f.get(); + pool.shutdown(); + + // Sample-based verification: query a moderate topK (well under total dataset + // size to avoid native HNSW edge cases) and confirm each result decrypts to + // its expected plaintext. This catches nonce-reuse, Cipher-reuse, or other + // concurrency bugs in the encryption layer without exercising the + // top-K-equals-N native code path. + List results = col.query( + ZvecSearch.vector("e", new float[] {1f, 0f, 0f, 0f}) + .topK(100).project("body").build()); + assertEquals(100, results.size()); + for (Doc d : results) { + assertEquals("secret-" + d.id(), d.fields().get("body")); + } + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractEncryptedCollectionRoundTripIntegrationTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractEncryptedCollectionRoundTripIntegrationTest.java new file mode 100644 index 000000000..3da93539b --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractEncryptedCollectionRoundTripIntegrationTest.java @@ -0,0 +1,52 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.zvec.crypto.KeyProvider; + +public abstract class AbstractEncryptedCollectionRoundTripIntegrationTest { + + @Test + void 
thousandDocsRoundTripWithBodyEncrypted(@TempDir Path tmp) { + Path path = tmp.resolve("docs"); + CollectionSchema schema = ZvecSchemas.collection("docs") + .string("title").string("body").encrypted("k1") + .vector("embed", 4).balanced() + .build(); + KeyProvider provider = kid -> { + byte[] k = new byte[32]; + k[0] = 7; + return k; + }; + + try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) { + List docs = new ArrayList<>(); + for (int i = 0; i < 1000; i++) { + docs.add(Doc.of("d" + i) + .field("title", "t" + i) + .field("body", "secret-" + i) + .vector("embed", new float[] {1f, 0f, 0f, (float) i / 1000f})); + } + col.insert(docs); + + List results = col.query( + ZvecSearch.vector("embed", new float[] {1f, 0f, 0f, 0f}) + .topK(10).project("title", "body").build()); + + assertEquals(10, results.size()); + for (Doc d : results) { + String id = d.id(); + assertTrue(id.startsWith("d")); + int idx = Integer.parseInt(id.substring(1)); + assertEquals("t" + idx, d.fields().get("title")); + assertEquals("secret-" + idx, d.fields().get("body")); + } + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractEncryptionMetadataMismatchIntegrationTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractEncryptionMetadataMismatchIntegrationTest.java new file mode 100644 index 000000000..c452148be --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractEncryptionMetadataMismatchIntegrationTest.java @@ -0,0 +1,61 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.zvec.crypto.EncryptionMetadataIOException; +import org.zvec.crypto.EncryptionMetadataMismatchException; +import org.zvec.crypto.KeyProvider; +import org.zvec.crypto.SidecarMetadata; + +public abstract class 
AbstractEncryptionMetadataMismatchIntegrationTest { + + @Test + void corruptedSidecarSurfacesAsIOException(@TempDir Path tmp) throws Exception { + Path path = tmp.resolve("docs"); + CollectionSchema schema = ZvecSchemas.collection("docs") + .string("body").encrypted("k1") + .vector("e", 4).build(); + KeyProvider provider = kid -> new byte[32]; + try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {} + + Files.writeString(path.resolve(SidecarMetadata.FILENAME), "totally-not-json{"); + assertThrows(EncryptionMetadataIOException.class, + () -> Zvec.openWithKeys(path.toString(), provider)); + } + + @Test + void renamedFieldInSidecarSurfacesAsMismatch(@TempDir Path tmp) throws Exception { + Path path = tmp.resolve("docs"); + CollectionSchema schema = ZvecSchemas.collection("docs") + .string("body").encrypted("k1") + .vector("e", 4).build(); + KeyProvider provider = kid -> new byte[32]; + try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {} + + String text = Files.readString(path.resolve(SidecarMetadata.FILENAME)); + Files.writeString(path.resolve(SidecarMetadata.FILENAME), + text.replace("\"body\":", "\"renamed_field\":")); + assertThrows(EncryptionMetadataMismatchException.class, + () -> Zvec.openWithKeys(path.toString(), provider)); + } + + @Test + void mismatchedCollectionNameSurfacesAsMismatch(@TempDir Path tmp) throws Exception { + Path path = tmp.resolve("docs"); + CollectionSchema schema = ZvecSchemas.collection("docs") + .string("body").encrypted("k1") + .vector("e", 4).build(); + KeyProvider provider = kid -> new byte[32]; + try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {} + + String text = Files.readString(path.resolve(SidecarMetadata.FILENAME)); + Files.writeString(path.resolve(SidecarMetadata.FILENAME), + text.replace("\"docs\"", "\"OTHER\"")); + assertThrows(EncryptionMetadataMismatchException.class, + () -> Zvec.openWithKeys(path.toString(), provider)); + } +} diff --git 
a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractKeyRotationIntegrationTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractKeyRotationIntegrationTest.java
new file mode 100644
index 000000000..61fc23b7d
--- /dev/null
+++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractKeyRotationIntegrationTest.java
@@ -0,0 +1,75 @@
+package org.zvec;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+import org.zvec.crypto.KeyProvider;
+import org.zvec.crypto.KeyResolutionException;
+
+public abstract class AbstractKeyRotationIntegrationTest {
+
+    @Test
+    void oldRecordsDecryptableUnderNewActiveKey(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("body").encrypted("k1")
+                .vector("e", 4).build();
+
+        Map<String, byte[]> keys = new HashMap<>(); // key id -> 32-byte key material
+        byte[] k1 = new byte[32]; k1[0] = 1;
+        byte[] k2 = new byte[32]; k2[0] = 2;
+        keys.put("k1", k1);
+        keys.put("k2", k2);
+        KeyProvider provider = kid -> keys.get(kid);
+
+        try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {
+            col.insert(List.of(Doc.of("d1").field("body", "old-text")
+                    .vector("e", new float[] {1f, 0f, 0f, 0f})));
+
+            col.setActiveKeyId("body", "k2");
+            col.insert(List.of(Doc.of("d2").field("body", "new-text")
+                    .vector("e", new float[] {1f, 0f, 0f, 0f})));
+
+            List<Doc> results = col.query(
+                    ZvecSearch.vector("e", new float[] {1f, 0f, 0f, 0f})
+                            .topK(2).project("body").build());
+            assertEquals(2, results.size());
+            Map<String, String> byId = new HashMap<>(); // doc id -> projected "body" value
+            for (Doc d : results) byId.put(d.id(), (String) d.fields().get("body"));
+            assertEquals("old-text", byId.get("d1"));
+            assertEquals("new-text", byId.get("d2"));
+        }
+    }
+
+    @Test
+    void
revokingOldKeyMakesOldRecordsUnreadable(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("body").encrypted("k1")
+                .vector("e", 4).build();
+
+        Map<String, byte[]> keys = new HashMap<>(); // key id -> 32-byte key material
+        byte[] k1 = new byte[32]; k1[0] = 1;
+        byte[] k2 = new byte[32]; k2[0] = 2;
+        keys.put("k1", k1);
+        keys.put("k2", k2);
+
+        try (Collection col = Zvec.createAndOpen(path.toString(), schema, kid -> keys.get(kid))) {
+            col.insert(List.of(Doc.of("d1").field("body", "old")
+                    .vector("e", new float[] {1f, 0f, 0f, 0f})));
+        }
+
+        keys.remove("k1"); // the provider lambda now yields null for "k1"
+        try (Collection col = Zvec.openWithKeys(path.toString(), kid -> keys.get(kid))) {
+            assertThrows(KeyResolutionException.class, () ->
+                    col.query(ZvecSearch.vector("e", new float[] {1f, 0f, 0f, 0f})
+                            .topK(1).project("body").build()));
+        }
+    }
+}
diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractLargeEncryptedPayloadIntegrationTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractLargeEncryptedPayloadIntegrationTest.java
new file mode 100644
index 000000000..405fd033f
--- /dev/null
+++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractLargeEncryptedPayloadIntegrationTest.java
@@ -0,0 +1,35 @@
+package org.zvec;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import java.nio.file.Path;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+import org.zvec.crypto.KeyProvider;
+
+public abstract class AbstractLargeEncryptedPayloadIntegrationTest {
+
+    @Test
+    void oneMegabytePayloadRoundTrips(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("body").encrypted("k1")
+                .vector("e", 4).build();
+        KeyProvider provider = kid -> new byte[32];
+
+        char[] chars = new char[1024 * 1024];
+        for (int i = 0; i < chars.length; i++) chars[i] = (char) ('a' + (i % 26)); // cycles 'a'..'z'
+        String big = new String(chars);
+
+
try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {
+            col.insert(List.of(Doc.of("d1").field("body", big)
+                    .vector("e", new float[] {1f, 0f, 0f, 0f})));
+
+            List<Doc> results = col.query(
+                    ZvecSearch.vector("e", new float[] {1f, 0f, 0f, 0f})
+                            .topK(1).project("body").build());
+            assertEquals(big, results.get(0).fields().get("body"));
+        }
+    }
+}
diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractOpenWithoutKeysIntegrationTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractOpenWithoutKeysIntegrationTest.java
new file mode 100644
index 000000000..210c02b0a
--- /dev/null
+++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractOpenWithoutKeysIntegrationTest.java
@@ -0,0 +1,36 @@
+package org.zvec;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.nio.file.Path;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+import org.zvec.crypto.EncryptedCollectionException;
+import org.zvec.crypto.KeyProvider;
+
+public abstract class AbstractOpenWithoutKeysIntegrationTest {
+
+    @Test
+    void plainOpenOnEncryptedCollectionThrows(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("body").encrypted("k1")
+                .vector("e", 4).build();
+        KeyProvider provider = kid -> new byte[32];
+        try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {}
+
+        assertThrows(EncryptedCollectionException.class, () -> Zvec.open(path.toString()));
+    }
+
+    @Test
+    void openWithKeysWithNullProviderThrows(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("body").encrypted("k1")
+                .vector("e", 4).build();
+        KeyProvider provider = kid -> new byte[32];
+        try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {}
+
+        assertThrows(NullPointerException.class, () -> Zvec.openWithKeys(path.toString(),
null));
+    }
+}
diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractQuickStartTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractQuickStartTest.java
new file mode 100644
index 000000000..8cc933a38
--- /dev/null
+++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/AbstractQuickStartTest.java
@@ -0,0 +1,53 @@
+package org.zvec;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+
+import java.nio.file.Path;
+import java.util.List;
+import org.junit.jupiter.api.Assumptions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+public abstract class AbstractQuickStartTest {
+  @TempDir Path tempDir;
+
+  @Test
+  void quickStartFlowWorks() {
+    assumeSupportedPlatform();
+
+    CollectionSchema schema =
+        ZvecSchemas.collection("docs").string("title").vector("embedding", 4).balanced().build();
+
+    try (Collection collection = Zvec.createAndOpen(tempDir.resolve("docs").toString(), schema)) {
+      collection.insert(
+          List.of(
+              Doc.of("doc_1")
+                  .field("title", "alpha")
+                  .vector("embedding", new float[] {1f, 0f, 0f, 0f}),
+              Doc.of("doc_2")
+                  .field("title", "beta")
+                  .vector("embedding", new float[] {0f, 1f, 0f, 0f})));
+
+      List<Doc> results =
+          collection.query(
+              ZvecSearch.vector("embedding", new float[] {1f, 0f, 0f, 0f})
+                  .topK(2)
+                  .project("title")
+                  .build());
+
+      assertFalse(results.isEmpty());
+      assertEquals("doc_1", results.get(0).id());
+      assertEquals("alpha", results.get(0).fields().get("title"));
+      assertNotNull(results.get(0).score());
+    }
+  }
+
+  private static void assumeSupportedPlatform() {
+    String osName = System.getProperty("os.name", "").toLowerCase(java.util.Locale.ROOT);
+    String osArch = System.getProperty("os.arch", "").toLowerCase(java.util.Locale.ROOT);
+    Assumptions.assumeTrue(osName.contains("mac")); // otherwise the test is aborted, not failed
+    Assumptions.assumeTrue(osArch.equals("aarch64") || osArch.equals("arm64"));
+  }
+}
diff
--git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/CollectionSchemaEncryptionTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/CollectionSchemaEncryptionTest.java new file mode 100644 index 000000000..e3639f01f --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/CollectionSchemaEncryptionTest.java @@ -0,0 +1,35 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; + +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.junit.jupiter.api.Test; +import org.zvec.crypto.EncryptionMetadata; +import org.zvec.crypto.EncryptionSpec; + +class CollectionSchemaEncryptionTest { + @Test + void existingThreeArgConstructorReturnsEmptyEncryption() { + CollectionSchema s = new CollectionSchema( + "docs", + List.of(new FieldSchema("title", DataType.STRING, false)), + List.of(new VectorSchema("e", DataType.VECTOR_FP32, 4))); + assertEquals(Optional.empty(), s.encryption()); + } + + @Test + void fourArgConstructorRetainsMetadata() { + EncryptionSpec spec = new EncryptionSpec("AES-256-GCM", "k1", Instant.now(), null); + EncryptionMetadata meta = new EncryptionMetadata(1, "docs", Map.of("title", spec)); + CollectionSchema s = new CollectionSchema( + "docs", + List.of(new FieldSchema("title", DataType.STRING, false)), + List.of(new VectorSchema("e", DataType.VECTOR_FP32, 4)), + meta); + assertSame(meta, s.encryption().orElseThrow()); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/HnswDefaultsTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/HnswDefaultsTest.java new file mode 100644 index 000000000..fc3d674ed --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/HnswDefaultsTest.java @@ -0,0 +1,74 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import 
org.zvec.internal.HnswDefaults; + +class HnswDefaultsTest { + @Test + void resolvesBalancedDefaultsWhenNoHintsAreProvided() { + VectorSchema schema = new VectorSchema("embedding", DataType.VECTOR_FP32, 128); + VectorQuery query = VectorQuery.of("embedding", new float[] {1.0f, 2.0f}); + + assertEquals(new HnswIndexParams(16, 200), HnswDefaults.resolveIndexParams(schema)); + assertEquals( + new HnswQueryParams(64, 0.0f, false, false), HnswDefaults.resolveQueryParams(schema, query)); + } + + @Test + void resolvesProfileAndDocCountDefaults() { + VectorSchema schema = + new VectorSchema("embedding", DataType.VECTOR_FP32, 128) + .withTuningProfile(TuningProfile.FAST, 100_000L); + VectorQuery query = VectorQuery.of("embedding", new float[] {1.0f, 2.0f}); + + assertEquals(new HnswIndexParams(12, 120), HnswDefaults.resolveIndexParams(schema)); + assertEquals( + new HnswQueryParams(32, 0.0f, false, false), HnswDefaults.resolveQueryParams(schema, query)); + } + + @Test + void respectsExplicitProfilesWithoutExpectedDocCount() { + VectorSchema schema = + new VectorSchema("embedding", DataType.VECTOR_FP32, 128) + .withTuningProfile(TuningProfile.ACCURATE); + VectorQuery query = + VectorQuery.of("embedding", new float[] {1.0f, 2.0f}) + .withTuningProfile(TuningProfile.FAST); + + assertEquals(new HnswIndexParams(24, 300), HnswDefaults.resolveIndexParams(schema)); + assertEquals( + new HnswQueryParams(32, 0.0f, false, false), HnswDefaults.resolveQueryParams(schema, query)); + } + + @Test + void queryProfileOverridesSchemaProfileWhenResolvingQueryDefaults() { + VectorSchema schema = + new VectorSchema("embedding", DataType.VECTOR_FP32, 128) + .withTuningProfile(TuningProfile.FAST, 1_000_000L); + VectorQuery query = + VectorQuery.of("embedding", new float[] {1.0f, 2.0f}) + .withTuningProfile(TuningProfile.ACCURATE); + + assertEquals( + new HnswQueryParams(128, 0.0f, false, false), + HnswDefaults.resolveQueryParams(schema, query)); + } + + @Test + void 
rawParamsOverrideProfilesAndDocCountHints() { + VectorSchema schema = + new VectorSchema("embedding", DataType.VECTOR_FP32, 128) + .withTuningProfile(TuningProfile.ACCURATE, 1_000_000L) + .withHnswIndex(new HnswIndexParams(18, 220)); + VectorQuery query = + VectorQuery.of("embedding", new float[] {1.0f, 2.0f}) + .withTuningProfile(TuningProfile.BALANCED) + .hnsw(new HnswQueryParams(144, 0.0f, false, true)); + + assertEquals(new HnswIndexParams(18, 220), HnswDefaults.resolveIndexParams(schema)); + assertEquals( + new HnswQueryParams(144, 0.0f, false, true), HnswDefaults.resolveQueryParams(schema, query)); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/HnswParamsModelTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/HnswParamsModelTest.java new file mode 100644 index 000000000..6e3274801 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/HnswParamsModelTest.java @@ -0,0 +1,72 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +class HnswParamsModelTest { + @Test + void hnswIndexParamsRejectNonPositiveValues() { + assertThrows(IllegalArgumentException.class, () -> new HnswIndexParams(0, 100)); + assertThrows(IllegalArgumentException.class, () -> new HnswIndexParams(16, 0)); + } + + @Test + void hnswQueryParamsRejectInvalidEfAndRadius() { + assertThrows(IllegalArgumentException.class, () -> new HnswQueryParams(0, 0.0f, false, false)); + assertThrows( + IllegalArgumentException.class, () -> new HnswQueryParams(64, -1.0f, false, false)); + } + + @Test + void vectorSchemaCanCarryOptionalHnswIndexParams() { + VectorSchema schema = new VectorSchema("embedding", DataType.VECTOR_FP32, 128); + HnswIndexParams params = new HnswIndexParams(32, 300); + + VectorSchema 
configured = schema.withHnswIndex(params); + + assertNull(schema.hnswIndexParams()); + assertEquals(params, configured.hnswIndexParams()); + } + + @Test + void vectorSchemaClearsTuningHintsWhenRawIndexParamsAreApplied() { + VectorSchema schema = + new VectorSchema("embedding", DataType.VECTOR_FP32, 128) + .withTuningProfile(TuningProfile.ACCURATE, 1_000_000L); + HnswIndexParams params = new HnswIndexParams(32, 300); + + VectorSchema configured = schema.withHnswIndex(params); + + assertEquals(TuningProfile.ACCURATE, schema.tuningProfile()); + assertEquals(Long.valueOf(1_000_000L), schema.expectedDocCount()); + assertNull(configured.tuningProfile()); + assertNull(configured.expectedDocCount()); + assertEquals(params, configured.hnswIndexParams()); + } + + @Test + void vectorQueryCanCarryOptionalHnswQueryParams() { + HnswQueryParams params = new HnswQueryParams(128, 0.0f, false, true); + + VectorQuery query = VectorQuery.of("embedding", new float[] {1.0f, 2.0f}).hnsw(params); + + assertSame(params, query.hnswQueryParams()); + } + + @Test + void vectorQueryClearsTuningHintsWhenRawHnswParamsAreApplied() { + HnswQueryParams params = new HnswQueryParams(128, 0.0f, false, true); + + VectorQuery query = + VectorQuery.of("embedding", new float[] {1.0f, 2.0f}) + .withTuningProfile(TuningProfile.BALANCED) + .hnsw(params); + + assertNull(query.tuningProfile()); + assertSame(params, query.hnswQueryParams()); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/SchemaModelTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/SchemaModelTest.java new file mode 100644 index 000000000..9fee7b71a --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/SchemaModelTest.java @@ -0,0 +1,43 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; +import 
org.junit.jupiter.api.Test; + +class SchemaModelTest { + @Test + void duplicateFieldNamesAcrossScalarAndVectorDefinitionsThrow() { + FieldSchema field = new FieldSchema("shared", DataType.STRING, false); + VectorSchema vector = new VectorSchema("shared", DataType.VECTOR_FP32, 128); + + assertThrows( + IllegalArgumentException.class, + () -> new CollectionSchema("items", List.of(field), List.of(vector))); + } + + @Test + void nonVectorDataTypeForVectorSchemaThrows() { + assertThrows( + IllegalArgumentException.class, () -> new VectorSchema("embedding", DataType.STRING, 128)); + } + + @Test + void scalarAndVectorDefinitionsAreStoredAndRetrievable() { + FieldSchema id = new FieldSchema("id", DataType.STRING, false); + FieldSchema active = new FieldSchema("active", DataType.BOOL, true); + VectorSchema embedding = new VectorSchema("embedding", DataType.VECTOR_FP32, 1536); + + CollectionSchema schema = + new CollectionSchema("items", List.of(id, active), List.of(embedding)); + + assertEquals("items", schema.name()); + assertEquals(List.of(id, active), schema.fields()); + assertEquals(List.of(embedding), schema.vectors()); + assertSame(id, schema.field("id")); + assertSame(active, schema.field("active")); + assertSame(embedding, schema.vector("embedding")); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSchemasBuilderTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSchemasBuilderTest.java new file mode 100644 index 000000000..0e0872bc2 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSchemasBuilderTest.java @@ -0,0 +1,73 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; +import org.junit.jupiter.api.Test; + +class ZvecSchemasBuilderTest { + @Test + void buildsSchemaWithBalancedVectorDefaults() { + CollectionSchema schema = + ZvecSchemas.collection("items") + .string("title") + 
.bool("active") + .int64("rank") + .doubleField("score") + .vector("embedding", 4) + .expectedDocCount(1_000_000L) + .balanced() + .build(); + + assertEquals("items", schema.name()); + assertEquals( + List.of( + new FieldSchema("title", DataType.STRING, false), + new FieldSchema("active", DataType.BOOL, false), + new FieldSchema("rank", DataType.INT64, false), + new FieldSchema("score", DataType.DOUBLE, false)), + schema.fields()); + assertEquals(1, schema.vectors().size()); + assertEquals("embedding", schema.vector("embedding").name()); + assertEquals(4, schema.vector("embedding").dimension()); + assertEquals(TuningProfile.BALANCED, schema.vector("embedding").tuningProfile()); + assertEquals(Long.valueOf(1_000_000L), schema.vector("embedding").expectedDocCount()); + } + + @Test + void tuningMethodsRequireAnActiveVectorField() { + ZvecSchemas.Builder builder = ZvecSchemas.collection("items").string("title"); + + IllegalStateException fast = + assertThrows(IllegalStateException.class, builder::fast); + IllegalStateException balanced = + assertThrows(IllegalStateException.class, builder::balanced); + IllegalStateException accurate = + assertThrows(IllegalStateException.class, builder::accurate); + IllegalStateException expectedDocCount = + assertThrows(IllegalStateException.class, () -> builder.expectedDocCount(10L)); + + assertEquals("fast() must follow vector(name, dimension)", fast.getMessage()); + assertEquals("balanced() must follow vector(name, dimension)", balanced.getMessage()); + assertEquals("accurate() must follow vector(name, dimension)", accurate.getMessage()); + assertEquals( + "expectedDocCount(...) 
must follow vector(name, dimension)", + expectedDocCount.getMessage()); + } + + @Test + void laterProfilesOverrideEarlierProfilesOnTheActiveVector() { + CollectionSchema schema = + ZvecSchemas.collection("items") + .vector("first", 2) + .fast() + .vector("second", 4) + .balanced() + .accurate() + .build(); + + assertEquals(TuningProfile.FAST, schema.vector("first").tuningProfile()); + assertEquals(TuningProfile.ACCURATE, schema.vector("second").tuningProfile()); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSchemasEncryptionTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSchemasEncryptionTest.java new file mode 100644 index 000000000..420a738ec --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSchemasEncryptionTest.java @@ -0,0 +1,69 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.zvec.crypto.EncryptionMetadata; + +class ZvecSchemasEncryptionTest { + @Test + void encryptedAttachesToPrecedingStringField() { + CollectionSchema s = ZvecSchemas.collection("docs") + .string("title") + .string("body").encrypted("body-key-v1") + .vector("embed", 4).balanced() + .build(); + EncryptionMetadata meta = s.encryption().orElseThrow(); + assertEquals(1, meta.fields().size()); + assertEquals("body-key-v1", meta.spec("body").activeKeyId()); + assertEquals("docs", meta.collectionName()); + assertEquals("AES-256-GCM", meta.spec("body").alg()); + } + + @Test + void encryptedRequiresPrecedingStringField() { + ZvecSchemas.Builder b1 = ZvecSchemas.collection("docs"); + assertThrows(IllegalStateException.class, () -> b1.encrypted("k")); + + ZvecSchemas.Builder b2 = ZvecSchemas.collection("docs").vector("e", 4); + 
assertThrows(IllegalStateException.class, () -> b2.encrypted("k")); + + ZvecSchemas.Builder b3 = ZvecSchemas.collection("docs").int64("salary"); + assertThrows(org.zvec.crypto.UnsupportedFieldTypeException.class, () -> b3.encrypted("k")); + } + + @Test + void rejectsDuplicateEncryptionOnSameField() { + ZvecSchemas.Builder b = ZvecSchemas.collection("docs") + .string("body").encrypted("k1"); + assertThrows(IllegalStateException.class, () -> b.encrypted("k2")); + } + + @Test + void schemaWithoutEncryptedFieldsHasEmptyEncryption() { + CollectionSchema s = ZvecSchemas.collection("docs") + .string("title").vector("e", 4).build(); + assertTrue(s.encryption().isEmpty()); + } + + @Test + void encryptedWithStaticKeyEmbedsProvider() { + byte[] key = new byte[32]; + key[0] = 7; + CollectionSchema s = ZvecSchemas.collection("docs") + .string("body").encrypted("k1", key) + .vector("e", 4).build(); + org.zvec.crypto.KeyProvider provider = s.embeddedKeyProviders().orElseThrow().get("body"); + assertEquals(7, provider.resolve("k1")[0]); + org.junit.jupiter.api.Assertions.assertNull(provider.resolve("other")); + } + + @Test + void encryptedWithStaticKeyRejectsBadLength() { + ZvecSchemas.Builder b = ZvecSchemas.collection("docs").string("body"); + assertThrows(IllegalArgumentException.class, () -> b.encrypted("k1", new byte[16])); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSearchBuilderTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSearchBuilderTest.java new file mode 100644 index 000000000..85d1cdaed --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/ZvecSearchBuilderTest.java @@ -0,0 +1,59 @@ +package org.zvec; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static 
org.junit.jupiter.api.Assertions.assertNull; + +import org.junit.jupiter.api.Test; + +class ZvecSearchBuilderTest { + @Test + void buildsBalancedProjectedQuery() { + VectorQuery query = + ZvecSearch.vector("embedding", new float[] {1f, 0f, 0f, 0f}) + .topK(7) + .balanced() + .project("title", "category") + .build(); + + assertEquals("embedding", query.fieldName()); + assertArrayEquals(new float[] {1f, 0f, 0f, 0f}, query.queryVector()); + assertEquals(7, query.topK()); + assertEquals(TuningProfile.BALANCED, query.tuningProfile()); + assertTrue(query.outputFieldsSpecified()); + assertEquals(java.util.List.of("title", "category"), query.outputFields()); + assertFalse(query.includeVector()); + assertNull(query.filter()); + assertNull(query.hnswQueryParams()); + } + + @Test + void laterProfilesOverrideEarlierProfiles() { + VectorQuery query = + ZvecSearch.vector("embedding", new float[] {1f, 0f, 0f, 0f}) + .fast() + .accurate() + .build(); + + assertNull(query.hnswQueryParams()); + assertEquals(TuningProfile.ACCURATE, query.tuningProfile()); + } + + @Test + void includeVectorAndFilterAreCarriedIntoVectorQuery() { + VectorQuery query = + ZvecSearch.vector("embedding", new float[] {1f, 0f, 0f, 0f}) + .fast() + .includeVector() + .filter("title = 'alpha'") + .build(); + + assertTrue(query.includeVector()); + assertEquals("title = 'alpha'", query.filter()); + assertFalse(query.outputFieldsSpecified()); + assertEquals(TuningProfile.FAST, query.tuningProfile()); + assertNull(query.hnswQueryParams()); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AadEncoderTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AadEncoderTest.java new file mode 100644 index 000000000..95b29f82e --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AadEncoderTest.java @@ -0,0 +1,51 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static 
org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import org.junit.jupiter.api.Test;
+
+class AadEncoderTest {
+    @Test
+    void encodesLengthPrefixedConcatInBigEndian() {
+        byte[] aad = AadEncoder.encode("d1", "body", "docs");
+
+        byte[] id = "d1".getBytes(StandardCharsets.UTF_8);
+        byte[] field = "body".getBytes(StandardCharsets.UTF_8);
+        byte[] coll = "docs".getBytes(StandardCharsets.UTF_8);
+        ByteBuffer expected = ByteBuffer.allocate(12 + id.length + field.length + coll.length); // 3 x 4-byte length prefixes
+        expected.putInt(id.length).put(id);
+        expected.putInt(field.length).put(field);
+        expected.putInt(coll.length).put(coll);
+
+        assertArrayEquals(expected.array(), aad);
+    }
+
+    @Test
+    void embeddedUnitSeparatorByteIsTreatedAsData() {
+        // Field name with an embedded unit-separator (0x1F) byte; the length-prefix
+        // encoding treats it as plain data, not as a delimiter.
+        byte[] withSep = AadEncoder.encode("d1", "body\u001Fdocs", "docs");
+        byte[] without = AadEncoder.encode("d1", "body", "docs");
+        assertNotEquals(withSep.length, without.length);
+    }
+
+    @Test
+    void rejectsNullArgs() {
+        assertThrows(NullPointerException.class, () -> AadEncoder.encode(null, "f", "c"));
+        assertThrows(NullPointerException.class, () -> AadEncoder.encode("d", null, "c"));
+        assertThrows(NullPointerException.class, () -> AadEncoder.encode("d", "f", null));
+    }
+
+    @Test
+    void supportsEmoji() {
+        byte[] aad = AadEncoder.encode("👍", "body", "docs");
+        byte[] thumb = "👍".getBytes(StandardCharsets.UTF_8);
+        ByteBuffer first4 = ByteBuffer.wrap(aad, 0, 4); // leading length prefix of the first component
+        int len = first4.getInt();
+        org.junit.jupiter.api.Assertions.assertEquals(thumb.length, len);
+    }
+}
diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractCollectionSetActiveKeyIdTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractCollectionSetActiveKeyIdTest.java
new file
mode 100644
index 000000000..7fd0570bb
--- /dev/null
+++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractCollectionSetActiveKeyIdTest.java
@@ -0,0 +1,76 @@
+package org.zvec.crypto;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.nio.file.Path;
+import java.util.Optional;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+import org.zvec.Collection;
+import org.zvec.CollectionSchema;
+import org.zvec.Zvec;
+import org.zvec.ZvecSchemas;
+
+public abstract class AbstractCollectionSetActiveKeyIdTest {
+
+    @Test
+    void rotateUpdatesSidecar(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("body").encrypted("k1")
+                .vector("e", 4).build();
+        KeyProvider provider = kid -> new byte[32];
+
+        try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {
+            col.setActiveKeyId("body", "k2");
+            assertEquals("k2", col.encryptedSchema().activeKeyId("body"));
+        }
+        Optional<EncryptionMetadata> back = SidecarMetadata.read(path); // re-read after the collection is closed
+        assertNotNull(back.orElseThrow());
+        assertEquals("k2", back.get().spec("body").activeKeyId());
+        assertNotNull(back.get().spec("body").rotatedAt());
+    }
+
+    @Test
+    void rotateRejectsNonEncryptedField(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("body").encrypted("k1")
+                .vector("e", 4).build();
+        KeyProvider provider = kid -> new byte[32];
+        try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {
+            assertThrows(IllegalArgumentException.class, () -> col.setActiveKeyId("title", "k2")); // "title" is not in the schema
+        }
+    }
+
+    @Test
+    void postRotationReadStillDecryptsOldRecords(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        org.zvec.CollectionSchema schema = org.zvec.ZvecSchemas.collection("docs")
+
.string("body").encrypted("k1")
+                .vector("e", 4).build();
+
+        java.util.Map<String, byte[]> keys = new java.util.HashMap<>(); // key id -> 32-byte key material
+        byte[] k1 = new byte[32]; k1[0] = 1;
+        byte[] k2 = new byte[32]; k2[0] = 2;
+        keys.put("k1", k1);
+        keys.put("k2", k2);
+        KeyProvider provider = kid -> keys.get(kid);
+
+        try (org.zvec.Collection col = org.zvec.Zvec.createAndOpen(path.toString(), schema, provider)) {
+            col.insert(java.util.List.of(
+                    org.zvec.Doc.of("d1").field("body", "secret-old")
+                            .vector("e", new float[] {1f, 0f, 0f, 0f})));
+
+            col.setActiveKeyId("body", "k2");
+
+            java.util.List<org.zvec.Doc> results = col.query(
+                    org.zvec.ZvecSearch.vector("e", new float[] {1f, 0f, 0f, 0f})
+                            .topK(1).project("body").build());
+            assertEquals(1, results.size());
+            assertEquals("secret-old", results.get(0).fields().get("body"));
+        }
+    }
+}
diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractZvecCreateAndOpenWithProviderTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractZvecCreateAndOpenWithProviderTest.java
new file mode 100644
index 000000000..b225dd333
--- /dev/null
+++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractZvecCreateAndOpenWithProviderTest.java
@@ -0,0 +1,89 @@
+package org.zvec.crypto;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Optional;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+import org.zvec.CollectionSchema;
+import org.zvec.Zvec;
+import org.zvec.ZvecSchemas;
+
+public abstract class AbstractZvecCreateAndOpenWithProviderTest {
+
+    @Test
+    void writesSidecarBeforeNativeOpen(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("title").string("body").encrypted("k1")
+                .vector("e", 4).build();
+
KeyProvider provider = kid -> "k1".equals(kid) ? new byte[32] : null;
+
+        try (org.zvec.Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {
+            assertTrue(Files.exists(path.resolve(SidecarMetadata.FILENAME)));
+            Optional<EncryptionMetadata> meta = SidecarMetadata.read(path);
+            assertEquals("docs", meta.orElseThrow().collectionName());
+            assertEquals(java.util.Set.of("body"), meta.get().encryptedFieldNames()); // "title" is declared without encrypted(...)
+        }
+    }
+
+    @Test
+    void rejectsNullProviderWhenSchemaHasEncryptedField(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("body").encrypted("k1")
+                .vector("e", 4).build();
+        assertThrows(NullPointerException.class,
+                () -> Zvec.createAndOpen(path.toString(), schema, null));
+    }
+
+    @Test
+    void twoArgCreateAndOpenAcceptsAllStaticKey(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        byte[] key = new byte[32];
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("body").encrypted("k1", key)
+                .vector("e", 4).build();
+        try (org.zvec.Collection col = Zvec.createAndOpen(path.toString(), schema)) {
+            assertEquals(java.util.Set.of("body"), col.encryptedSchema().encryptedFieldNames());
+        }
+    }
+
+    @Test
+    void twoArgCreateAndOpenRejectsKeyIdOnlyEncryption(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("body").encrypted("k1")
+                .vector("e", 4).build();
+        assertThrows(EncryptedCollectionException.class,
+                () -> Zvec.createAndOpen(path.toString(), schema));
+    }
+
+    @Test
+    void twoArgCreateAndOpenAllowsNoEncryption(@TempDir Path tmp) {
+        Path path = tmp.resolve("docs");
+        CollectionSchema schema = ZvecSchemas.collection("docs")
+                .string("title").vector("e", 4).build();
+        try (org.zvec.Collection col = Zvec.createAndOpen(path.toString(), schema)) {
+            assertEquals(java.util.Set.of(), col.encryptedSchema().encryptedFieldNames());
+        }
+    }
+
+    @Test
+    void
threeArgCreateAndOpenInitializesNativeImplicitly(@TempDir Path tmp) { + // Smoke test: even though earlier tests in the same JVM may have initialized native, + // calling the encrypted entrypoint directly should not assume prior initialization. + Path path = tmp.resolve("docs"); + CollectionSchema schema = ZvecSchemas.collection("docs") + .string("body").encrypted("k1") + .vector("e", 4).build(); + KeyProvider provider = kid -> new byte[32]; + try (org.zvec.Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) { + assertEquals(java.util.Set.of("body"), col.encryptedSchema().encryptedFieldNames()); + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractZvecOpenWithKeysTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractZvecOpenWithKeysTest.java new file mode 100644 index 000000000..8f0f938a2 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AbstractZvecOpenWithKeysTest.java @@ -0,0 +1,42 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.file.Path; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.zvec.Collection; +import org.zvec.CollectionSchema; +import org.zvec.Zvec; +import org.zvec.ZvecSchemas; + +public abstract class AbstractZvecOpenWithKeysTest { + + @Test + void reopenWithKeysAttachesEncryption(@TempDir Path tmp) { + Path path = tmp.resolve("docs"); + CollectionSchema schema = ZvecSchemas.collection("docs") + .string("body").encrypted("k1") + .vector("e", 4).build(); + KeyProvider provider = kid -> new byte[32]; + try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) { /* close */ } + + try (Collection col = Zvec.openWithKeys(path.toString(), provider)) { + EncryptedSchema es = col.encryptedSchema(); + assertEquals(java.util.Set.of("body"), 
es.encryptedFieldNames()); + } + } + + @Test + void openOnEncryptedCollectionWithoutProviderThrows(@TempDir Path tmp) { + Path path = tmp.resolve("docs"); + CollectionSchema schema = ZvecSchemas.collection("docs") + .string("body").encrypted("k1") + .vector("e", 4).build(); + KeyProvider provider = kid -> new byte[32]; + try (Collection col = Zvec.createAndOpen(path.toString(), schema, provider)) {} + + assertThrows(EncryptedCollectionException.class, () -> Zvec.open(path.toString())); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AesGcm256Test.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AesGcm256Test.java new file mode 100644 index 000000000..9e084b788 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/AesGcm256Test.java @@ -0,0 +1,131 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.security.SecureRandom; +import org.junit.jupiter.api.Test; + +class AesGcm256Test { + @Test + void roundTripsArbitraryPlaintext() { + Aead aead = new AesGcm256(); + byte[] key = new byte[32]; + new SecureRandom().nextBytes(key); + byte[] nonce = new byte[12]; + new SecureRandom().nextBytes(nonce); + byte[] plaintext = "hello world".getBytes(); + byte[] aad = "id1|body|docs".getBytes(); + + byte[] ct = aead.seal(key, nonce, plaintext, aad); + byte[] back = aead.open(key, nonce, ct, aad); + assertArrayEquals(plaintext, back); + } + + @Test + void emptyPlaintextIsSupported() { + Aead aead = new AesGcm256(); + byte[] key = new byte[32]; + byte[] nonce = new byte[12]; + byte[] ct = aead.seal(key, nonce, new byte[0], new byte[0]); + assertEquals(16, ct.length); // tag only + assertArrayEquals(new byte[0], aead.open(key, nonce, ct, new byte[0])); + } + + @Test + 
void tagFlipDetected() { + Aead aead = new AesGcm256(); + byte[] key = new byte[32]; + byte[] nonce = new byte[12]; + byte[] ct = aead.seal(key, nonce, "secret".getBytes(), "aad".getBytes()); + ct[ct.length - 1] ^= 0x01; + assertThrows(AuthenticationFailedException.class, + () -> aead.open(key, nonce, ct, "aad".getBytes())); + } + + @Test + void aadMismatchDetected() { + Aead aead = new AesGcm256(); + byte[] key = new byte[32]; + byte[] nonce = new byte[12]; + byte[] ct = aead.seal(key, nonce, "secret".getBytes(), "aad-A".getBytes()); + assertThrows(AuthenticationFailedException.class, + () -> aead.open(key, nonce, ct, "aad-B".getBytes())); + } + + @Test + void wrongKeyLengthRejected() { + Aead aead = new AesGcm256(); + byte[] shortKey = new byte[31]; + byte[] nonce = new byte[12]; + assertThrows(IllegalArgumentException.class, + () -> aead.seal(shortKey, nonce, new byte[1], new byte[0])); + } + + @Test + void wrongNonceLengthRejected() { + Aead aead = new AesGcm256(); + byte[] key = new byte[32]; + byte[] shortNonce = new byte[8]; + assertThrows(IllegalArgumentException.class, + () -> aead.seal(key, shortNonce, new byte[1], new byte[0])); + } + + @Test + void nistTestVector() { + // NIST GCM test vector: Key 256 bits, IV 96 bits, PT empty, AAD empty. 
+ // From NIST CAVP gcmEncryptExtIV256.rsp, Count = 0 + byte[] key = parseHex("b52c505a37d78eda5dd34f20c22540ea1b58963cf8e5bf8ffa85f9f2492505b4"); + byte[] iv = parseHex("516c33929df5a3284ff463d7"); + byte[] expectedTag = parseHex("bdc1ac884d332457a1d2664f168c76f0"); + byte[] ct = new AesGcm256().seal(key, iv, new byte[0], new byte[0]); + assertArrayEquals(expectedTag, ct); + } + + @Test + void identifiesAsAesGcmInEnvelopeAlgByte() { + assertEquals(Envelope.ALG_AES_256_GCM, new AesGcm256().algId()); + } + + @Test + void nonceUniquenessOver1k() { + java.util.Set seen = new java.util.HashSet<>(); + SecureRandom rng = new SecureRandom(); + for (int i = 0; i < 1000; i++) { + byte[] n = new byte[12]; + rng.nextBytes(n); + seen.add(formatHex(n)); + } + assertNotEquals(0, seen.size()); + org.junit.jupiter.api.Assertions.assertTrue(seen.size() >= 999); + } + + private static byte[] parseHex(String hex) { + if ((hex.length() & 1) != 0) { + throw new IllegalArgumentException("hex length must be even"); + } + byte[] out = new byte[hex.length() / 2]; + for (int i = 0; i < out.length; i++) { + int hi = Character.digit(hex.charAt(i * 2), 16); + int lo = Character.digit(hex.charAt(i * 2 + 1), 16); + if (hi < 0 || lo < 0) { + throw new IllegalArgumentException("invalid hex: " + hex); + } + out[i] = (byte) ((hi << 4) | lo); + } + return out; + } + + private static String formatHex(byte[] bytes) { + char[] out = new char[bytes.length * 2]; + char[] digits = "0123456789abcdef".toCharArray(); + for (int i = 0; i < bytes.length; i++) { + int value = bytes[i] & 0xff; + out[i * 2] = digits[value >>> 4]; + out[i * 2 + 1] = digits[value & 0x0f]; + } + return new String(out); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/DecryptingProjectorTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/DecryptingProjectorTest.java new file mode 100644 index 000000000..9fe3a8822 --- /dev/null +++ 
b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/DecryptingProjectorTest.java @@ -0,0 +1,106 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.time.Instant; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.zvec.CollectionSchema; +import org.zvec.DataType; +import org.zvec.Doc; +import org.zvec.FieldSchema; +import org.zvec.VectorSchema; + +class DecryptingProjectorTest { + private static final byte[] KEY = new byte[32]; + static { KEY[0] = 9; } + + private static EncryptedSchema buildEncSchema() { + CollectionSchema cs = new CollectionSchema( + "docs", + List.of(new FieldSchema("title", DataType.STRING, false), + new FieldSchema("body", DataType.STRING, false)), + List.of(new VectorSchema("embed", DataType.VECTOR_FP32, 4))); + EncryptionSpec spec = new EncryptionSpec("AES-256-GCM", "k1", Instant.now(), null); + EncryptionMetadata meta = new EncryptionMetadata(1, "docs", Map.of("body", spec)); + return EncryptedSchema.reconcile(cs, meta, kid -> "k1".equals(kid) ? 
KEY : null); + } + + @Test + void roundTripsViaInsertor() { + EncryptedSchema es = buildEncSchema(); + Doc plain = Doc.of("d1").field("title", "alpha").field("body", "secret-text"); + Doc encrypted = EncryptingInsertor.transform(List.of(plain), es).get(0); + + Doc relayed = Doc.of("d1"); + encrypted.fields().forEach((k, v) -> { if (v instanceof String) relayed.field(k, (String) v); }); + + Doc decrypted = DecryptingProjector.transform(List.of(relayed), es).get(0); + assertEquals("alpha", decrypted.fields().get("title")); + assertEquals("secret-text", decrypted.fields().get("body")); + } + + @Test + void noneSentinelReturnsInputUnchanged() { + Doc input = Doc.of("d1").field("body", "anything"); + List out = DecryptingProjector.transform(List.of(input), EncryptedSchema.NONE); + assertSame(input, out.get(0)); + } + + @Test + void missingEncryptedFieldOnResultIsTolerated() { + EncryptedSchema es = buildEncSchema(); + Doc input = Doc.of("d1").field("title", "alpha"); + Doc out = DecryptingProjector.transform(List.of(input), es).get(0); + assertEquals("alpha", out.fields().get("title")); + assertEquals(false, out.fields().containsKey("body")); + } + + @Test + void aadMismatchSurfacesAsAuthenticationFailed() { + EncryptedSchema es = buildEncSchema(); + Doc plain = Doc.of("d1").field("body", "secret"); + Doc encrypted = EncryptingInsertor.transform(List.of(plain), es).get(0); + + Doc relocated = Doc.of("d2"); + encrypted.fields().forEach((k, v) -> { if (v instanceof String) relocated.field(k, (String) v); }); + + assertThrows(AuthenticationFailedException.class, + () -> DecryptingProjector.transform(List.of(relocated), es)); + } + + @Test + void unknownKeyIdSurfacesAsKeyResolution() { + CollectionSchema cs = new CollectionSchema( + "docs", + List.of(new FieldSchema("body", DataType.STRING, false)), + List.of(new VectorSchema("embed", DataType.VECTOR_FP32, 4))); + EncryptionSpec spec = new EncryptionSpec("AES-256-GCM", "k1", Instant.now(), null); + EncryptionMetadata meta = 
new EncryptionMetadata(1, "docs", Map.of("body", spec)); + EncryptedSchema esEncrypt = EncryptedSchema.reconcile(cs, meta, kid -> KEY); + Doc plain = Doc.of("d1").field("body", "secret"); + Doc encrypted = EncryptingInsertor.transform(List.of(plain), esEncrypt).get(0); + + EncryptedSchema esDecryptNoKey = EncryptedSchema.reconcile(cs, meta, kid -> null); + Doc relayed = Doc.of("d1"); + encrypted.fields().forEach((k, v) -> { if (v instanceof String) relayed.field(k, (String) v); }); + assertThrows(KeyResolutionException.class, + () -> DecryptingProjector.transform(List.of(relayed), esDecryptNoKey)); + } + + @Test + void unknownVersionSurfacesAsEnvelopeFormat() { + EncryptedSchema es = buildEncSchema(); + byte[] junk = new byte[200]; + junk[0] = 0x09; + String b64 = java.util.Base64.getUrlEncoder().withoutPadding().encodeToString(junk); + Doc bad = Doc.of("d1").field("body", b64); + assertThrows(EnvelopeFormatException.class, + () -> DecryptingProjector.transform(List.of(bad), es)); + assertNotNull(es); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptedSchemaTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptedSchemaTest.java new file mode 100644 index 000000000..ff9d905be --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptedSchemaTest.java @@ -0,0 +1,81 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.time.Instant; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.zvec.CollectionSchema; +import org.zvec.DataType; +import org.zvec.FieldSchema; +import org.zvec.VectorSchema; + +class EncryptedSchemaTest { + private static 
CollectionSchema docsSchema() { + return new CollectionSchema( + "docs", + List.of(new FieldSchema("title", DataType.STRING, false), + new FieldSchema("body", DataType.STRING, false)), + List.of(new VectorSchema("embed", DataType.VECTOR_FP32, 4))); + } + + private static EncryptionMetadata metaFor(String collName, String fieldName, String keyId) { + EncryptionSpec spec = new EncryptionSpec("AES-256-GCM", keyId, Instant.now(), null); + return new EncryptionMetadata(1, collName, Map.of(fieldName, spec)); + } + + @Test + void reconcileSucceedsForMatchingFields() { + EncryptedSchema es = EncryptedSchema.reconcile( + docsSchema(), + metaFor("docs", "body", "k1"), + keyId -> new byte[32]); + assertNotNull(es); + assertTrue(es.isEncrypted("body")); + assertEquals("k1", es.activeKeyId("body")); + } + + @Test + void reconcileFailsWhenCollectionNamesDiffer() { + assertThrows(EncryptionMetadataMismatchException.class, () -> + EncryptedSchema.reconcile(docsSchema(), + metaFor("OTHER", "body", "k1"), keyId -> new byte[32])); + } + + @Test + void reconcileFailsWhenFieldMissing() { + assertThrows(EncryptionMetadataMismatchException.class, () -> + EncryptedSchema.reconcile(docsSchema(), + metaFor("docs", "missing", "k1"), keyId -> new byte[32])); + } + + @Test + void reconcileFailsWhenFieldNotString() { + CollectionSchema schema = new CollectionSchema( + "docs", + List.of(new FieldSchema("salary", DataType.INT64, false)), + List.of(new VectorSchema("embed", DataType.VECTOR_FP32, 4))); + assertThrows(EncryptionMetadataMismatchException.class, () -> + EncryptedSchema.reconcile(schema, + metaFor("docs", "salary", "k1"), keyId -> new byte[32])); + } + + @Test + void reconcileEmptyMetadataReturnsNoneSentinel() { + EncryptedSchema es = EncryptedSchema.reconcile( + docsSchema(), EncryptionMetadata.empty("docs"), keyId -> new byte[32]); + assertSame(EncryptedSchema.NONE, es); + assertEquals(false, es.isEncrypted("body")); + } + + @Test + void noneSentinelHasNoEncryptedFields() { + 
assertTrue(EncryptedSchema.NONE.encryptedFieldNames().isEmpty()); + assertEquals(false, EncryptedSchema.NONE.isEncrypted("any")); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptingInsertorTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptingInsertorTest.java new file mode 100644 index 000000000..d26f29ba5 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptingInsertorTest.java @@ -0,0 +1,99 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.zvec.CollectionSchema; +import org.zvec.DataType; +import org.zvec.Doc; +import org.zvec.FieldSchema; +import org.zvec.VectorSchema; + +class EncryptingInsertorTest { + private static final byte[] KEY = new byte[32]; + + private static EncryptedSchema buildEncSchema(String fieldName) { + CollectionSchema cs = new CollectionSchema( + "docs", + List.of(new FieldSchema("title", DataType.STRING, false), + new FieldSchema(fieldName, DataType.STRING, false)), + List.of(new VectorSchema("embed", DataType.VECTOR_FP32, 4))); + EncryptionSpec spec = new EncryptionSpec("AES-256-GCM", "k1", Instant.now(), null); + EncryptionMetadata meta = new EncryptionMetadata(1, "docs", Map.of(fieldName, spec)); + return EncryptedSchema.reconcile(cs, meta, kid -> "k1".equals(kid) ? 
KEY : null); + } + + @Test + void encryptsOnlyMarkedField() { + EncryptedSchema es = buildEncSchema("body"); + Doc input = Doc.of("d1").field("title", "alpha").field("body", "secret") + .vector("embed", new float[] {1f,0f,0f,0f}); + + List out = EncryptingInsertor.transform(List.of(input), es); + + assertEquals("alpha", out.get(0).fields().get("title")); + assertNotEquals("secret", out.get(0).fields().get("body")); + assertArrayEquals(new float[] {1f,0f,0f,0f}, out.get(0).vectors().get("embed")); + } + + @Test + void inputDocIsNotMutated() { + EncryptedSchema es = buildEncSchema("body"); + Doc input = Doc.of("d1").field("body", "secret"); + EncryptingInsertor.transform(List.of(input), es); + assertEquals("secret", input.fields().get("body")); + } + + @Test + void noEncryptedFieldOnDocIsTolerated() { + EncryptedSchema es = buildEncSchema("body"); + Doc input = Doc.of("d1").field("title", "alpha"); + List out = EncryptingInsertor.transform(List.of(input), es); + assertEquals("alpha", out.get(0).fields().get("title")); + assertEquals(false, out.get(0).fields().containsKey("body")); + } + + @Test + void noneSentinelReturnsInputUnchanged() { + Doc input = Doc.of("d1").field("body", "x"); + List out = EncryptingInsertor.transform(List.of(input), EncryptedSchema.NONE); + assertEquals(input.fields(), out.get(0).fields()); + } + + @Test + void roundTripsViaCodecAndAead() { + EncryptedSchema es = buildEncSchema("body"); + Doc input = Doc.of("d1").field("body", "the quick brown fox"); + Doc encrypted = EncryptingInsertor.transform(List.of(input), es).get(0); + + String b64 = (String) encrypted.fields().get("body"); + Envelope env = EnvelopeCodec.decodeBase64(b64); + assertEquals("k1", env.keyId()); + byte[] aad = AadEncoder.encode("d1", "body", "docs"); + byte[] plaintext = new AesGcm256().open(KEY, env.nonce(), env.ciphertext(), aad); + assertEquals("the quick brown fox", new String(plaintext, StandardCharsets.UTF_8)); + } + + @Test + void providerNullThrows() { + 
CollectionSchema cs = new CollectionSchema( + "docs", + List.of(new FieldSchema("body", DataType.STRING, false)), + List.of(new VectorSchema("embed", DataType.VECTOR_FP32, 4))); + EncryptionSpec spec = new EncryptionSpec("AES-256-GCM", "missing-key", Instant.now(), null); + EncryptionMetadata meta = new EncryptionMetadata(1, "docs", Map.of("body", spec)); + EncryptedSchema es = EncryptedSchema.reconcile(cs, meta, kid -> null); + Doc input = Doc.of("d1").field("body", "secret"); + KeyResolutionException e = assertThrows(KeyResolutionException.class, + () -> EncryptingInsertor.transform(List.of(input), es)); + assertTrue(e.getMessage().contains("missing-key")); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptionExceptionTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptionExceptionTest.java new file mode 100644 index 000000000..8657e8779 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptionExceptionTest.java @@ -0,0 +1,33 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +class EncryptionExceptionTest { + @Test + void hierarchyIsCorrect() { + assertTrue(EncryptionException.class.isAssignableFrom(EncryptionConfigException.class)); + assertTrue(EncryptionException.class.isAssignableFrom(EncryptionRuntimeException.class)); + assertTrue(EncryptionConfigException.class.isAssignableFrom(EncryptedCollectionException.class)); + assertTrue(EncryptionConfigException.class.isAssignableFrom(EncryptionMetadataMismatchException.class)); + assertTrue(EncryptionConfigException.class.isAssignableFrom(EncryptionMetadataIOException.class)); + assertTrue(EncryptionConfigException.class.isAssignableFrom(UnsupportedFieldTypeException.class)); + 
assertTrue(EncryptionRuntimeException.class.isAssignableFrom(KeyResolutionException.class)); + assertTrue(EncryptionRuntimeException.class.isAssignableFrom(EncryptionFailedException.class)); + assertTrue(EncryptionRuntimeException.class.isAssignableFrom(DecryptionException.class)); + assertTrue(DecryptionException.class.isAssignableFrom(EnvelopeFormatException.class)); + assertTrue(DecryptionException.class.isAssignableFrom(AuthenticationFailedException.class)); + assertTrue(RuntimeException.class.isAssignableFrom(EncryptionException.class)); + } + + @Test + void messagesAndCausesPropagate() { + Throwable cause = new IllegalStateException("inner"); + KeyResolutionException e = new KeyResolutionException("provider returned null for keyId='k1'", cause); + assertEquals("provider returned null for keyId='k1'", e.getMessage()); + assertSame(cause, e.getCause()); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptionMetadataTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptionMetadataTest.java new file mode 100644 index 000000000..1ae6e11a7 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EncryptionMetadataTest.java @@ -0,0 +1,56 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.time.Instant; +import java.util.Map; +import org.junit.jupiter.api.Test; + +class EncryptionMetadataTest { + @Test + void specFieldsRetained() { + EncryptionSpec s = new EncryptionSpec("AES-256-GCM", "k1", Instant.parse("2026-04-28T00:00:00Z"), null); + assertEquals("AES-256-GCM", s.alg()); + assertEquals("k1", s.activeKeyId()); + assertEquals(Instant.parse("2026-04-28T00:00:00Z"), s.createdAt()); + org.junit.jupiter.api.Assertions.assertNull(s.rotatedAt()); + } + + 
@Test + void specRejectsBadAlg() { + assertThrows(IllegalArgumentException.class, + () -> new EncryptionSpec("AES-128-CBC", "k1", Instant.now(), null)); + } + + @Test + void specRejectsEmptyKeyId() { + assertThrows(IllegalArgumentException.class, + () -> new EncryptionSpec("AES-256-GCM", "", Instant.now(), null)); + } + + @Test + void metadataExposesEncryptedFieldNames() { + EncryptionSpec spec = new EncryptionSpec("AES-256-GCM", "k1", Instant.now(), null); + EncryptionMetadata meta = new EncryptionMetadata(1, "docs", Map.of("body", spec, "ssn", spec)); + assertEquals(java.util.Set.of("body", "ssn"), meta.encryptedFieldNames()); + assertSame(spec, meta.spec("body")); + assertTrue(meta.isEncrypted("body")); + assertEquals(false, meta.isEncrypted("title")); + } + + @Test + void emptyMetadataConstant() { + EncryptionMetadata empty = EncryptionMetadata.empty("docs"); + assertTrue(empty.encryptedFieldNames().isEmpty()); + assertEquals("docs", empty.collectionName()); + } + + @Test + void rejectsUnknownVersion() { + assertThrows(IllegalArgumentException.class, + () -> new EncryptionMetadata(2, "docs", Map.of())); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EnvelopeCodecTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EnvelopeCodecTest.java new file mode 100644 index 000000000..38777f2db --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EnvelopeCodecTest.java @@ -0,0 +1,95 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Test; + +class EnvelopeCodecTest { + private static final byte[] NONCE = new byte[] {1,2,3,4,5,6,7,8,9,10,11,12}; + + @Test + void roundTripBinary() { + Envelope original = new Envelope( + Envelope.VERSION_V1, 
Envelope.ALG_AES_256_GCM, Envelope.PAYLOAD_STRING, + "key-1", NONCE, new byte[] {(byte)0xde, (byte)0xad, (byte)0xbe, (byte)0xef}); + byte[] encoded = EnvelopeCodec.encode(original); + Envelope decoded = EnvelopeCodec.decode(encoded); + + assertEquals(original.version(), decoded.version()); + assertEquals(original.alg(), decoded.alg()); + assertEquals(original.payloadType(), decoded.payloadType()); + assertEquals(original.keyId(), decoded.keyId()); + assertArrayEquals(original.nonce(), decoded.nonce()); + assertArrayEquals(original.ciphertext(), decoded.ciphertext()); + } + + @Test + void roundTripBase64UrlNoPad() { + Envelope original = new Envelope( + Envelope.VERSION_V1, Envelope.ALG_AES_256_GCM, Envelope.PAYLOAD_STRING, + "k", NONCE, new byte[] {1, 2, 3}); + String b64 = EnvelopeCodec.encodeBase64(original); + org.junit.jupiter.api.Assertions.assertFalse(b64.contains("=")); + org.junit.jupiter.api.Assertions.assertFalse(b64.contains("+")); + org.junit.jupiter.api.Assertions.assertFalse(b64.contains("/")); + Envelope decoded = EnvelopeCodec.decodeBase64(b64); + assertArrayEquals(original.ciphertext(), decoded.ciphertext()); + } + + @Test + void rejectsUnknownVersion() { + byte[] bad = new byte[] {(byte)0x02, 0x01, 0x00, 1, 'k', 1,2,3,4,5,6,7,8,9,10,11,12, 0,0,0}; + EnvelopeFormatException e = assertThrows(EnvelopeFormatException.class, () -> EnvelopeCodec.decode(bad)); + org.junit.jupiter.api.Assertions.assertTrue(e.getMessage().contains("version")); + } + + @Test + void rejectsZeroKeyIdLen() { + byte[] bad = new byte[] {(byte)0x01, 0x01, 0x00, 0, 1,2,3,4,5,6,7,8,9,10,11,12, 0,0,0}; + assertThrows(EnvelopeFormatException.class, () -> EnvelopeCodec.decode(bad)); + } + + @Test + void rejectsTruncatedBuffer() { + byte[] tooShort = new byte[] {0x01, 0x01, 0x00, 1}; + assertThrows(EnvelopeFormatException.class, () -> EnvelopeCodec.decode(tooShort)); + } + + @Test + void rejectsKeyIdLongerThan255() { + String kid = "k".repeat(256); + Envelope env = new Envelope( + 
Envelope.VERSION_V1, Envelope.ALG_AES_256_GCM, Envelope.PAYLOAD_STRING, + kid, NONCE, new byte[] {1}); + assertThrows(IllegalArgumentException.class, () -> EnvelopeCodec.encode(env)); + } + + @Test + void rejectsWrongNonceSize() { + byte[] shortNonce = new byte[8]; + Envelope env = new Envelope( + Envelope.VERSION_V1, Envelope.ALG_AES_256_GCM, Envelope.PAYLOAD_STRING, + "k", shortNonce, new byte[] {1}); + assertThrows(IllegalArgumentException.class, () -> EnvelopeCodec.encode(env)); + } + + @Test + void layoutMatchesSpec() { + String kid = "kid"; + Envelope env = new Envelope( + Envelope.VERSION_V1, Envelope.ALG_AES_256_GCM, Envelope.PAYLOAD_STRING, + kid, NONCE, new byte[] {(byte)0xff}); + byte[] enc = EnvelopeCodec.encode(env); + assertEquals(0x01, enc[0] & 0xff); + assertEquals(0x01, enc[1] & 0xff); + assertEquals(0x00, enc[2] & 0xff); + assertEquals(kid.length(), enc[3] & 0xff); + assertArrayEquals(kid.getBytes(StandardCharsets.UTF_8), + java.util.Arrays.copyOfRange(enc, 4, 4 + kid.length())); + assertArrayEquals(NONCE, + java.util.Arrays.copyOfRange(enc, 4 + kid.length(), 4 + kid.length() + 12)); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EnvelopeRelocationSecurityTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EnvelopeRelocationSecurityTest.java new file mode 100644 index 000000000..cee807db8 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/EnvelopeRelocationSecurityTest.java @@ -0,0 +1,85 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.time.Instant; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.zvec.CollectionSchema; +import org.zvec.DataType; +import org.zvec.Doc; +import org.zvec.FieldSchema; +import org.zvec.VectorSchema; + +class EnvelopeRelocationSecurityTest { + + private static final byte[] KEY = new byte[32]; + static { KEY[0] = 33; } + + private static 
EncryptedSchema encryptedSchema(String collName, String fieldName) { + CollectionSchema cs = new CollectionSchema( + collName, + List.of(new FieldSchema("title", DataType.STRING, false), + new FieldSchema(fieldName, DataType.STRING, false), + new FieldSchema("other", DataType.STRING, false)), + List.of(new VectorSchema("e", DataType.VECTOR_FP32, 4))); + EncryptionSpec spec = new EncryptionSpec("AES-256-GCM", "k1", Instant.now(), null); + EncryptionMetadata meta = new EncryptionMetadata(1, collName, Map.of(fieldName, spec, "other", spec)); + return EncryptedSchema.reconcile(cs, meta, kid -> "k1".equals(kid) ? KEY : null); + } + + @Test + void relocationAcrossDocsRejected() { + EncryptedSchema es = encryptedSchema("docs", "body"); + Doc d1 = EncryptingInsertor.transform( + List.of(Doc.of("d1").field("body", "secret-1")), es).get(0); + + Doc relocated = Doc.of("d2"); + relocated.field("body", (String) d1.fields().get("body")); + assertThrows(AuthenticationFailedException.class, + () -> DecryptingProjector.transform(List.of(relocated), es)); + } + + @Test + void relocationAcrossFieldsRejected() { + EncryptedSchema es = encryptedSchema("docs", "body"); + Doc d1 = EncryptingInsertor.transform( + List.of(Doc.of("d1").field("body", "secret-1")), es).get(0); + + Doc swapped = Doc.of("d1").field("other", (String) d1.fields().get("body")); + assertThrows(AuthenticationFailedException.class, + () -> DecryptingProjector.transform(List.of(swapped), es)); + } + + @Test + void relocationAcrossCollectionsRejected() { + EncryptedSchema esA = encryptedSchema("docsA", "body"); + EncryptedSchema esB = encryptedSchema("docsB", "body"); + Doc d1 = EncryptingInsertor.transform( + List.of(Doc.of("d1").field("body", "secret")), esA).get(0); + + Doc moved = Doc.of("d1").field("body", (String) d1.fields().get("body")); + assertThrows(AuthenticationFailedException.class, + () -> DecryptingProjector.transform(List.of(moved), esB)); + } + + @Test + void singleByteFlipAnywhereInEnvelopeRejected() 
{ + EncryptedSchema es = encryptedSchema("docs", "body"); + Doc d1 = EncryptingInsertor.transform( + List.of(Doc.of("d1").field("body", "secret-payload")), es).get(0); + String b64 = (String) d1.fields().get("body"); + byte[] raw = java.util.Base64.getUrlDecoder().decode(b64); + + for (int i = 0; i < raw.length; i++) { + byte[] copy = raw.clone(); + copy[i] ^= 0x01; + String corrupted = java.util.Base64.getUrlEncoder().withoutPadding().encodeToString(copy); + Doc bad = Doc.of("d1").field("body", corrupted); + assertThrows(EncryptionException.class, + () -> DecryptingProjector.transform(List.of(bad), es), + "byte index " + i + " flip should have failed"); + } + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/FilterFieldScannerTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/FilterFieldScannerTest.java new file mode 100644 index 000000000..2f1ac6f19 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/FilterFieldScannerTest.java @@ -0,0 +1,51 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Set; +import org.junit.jupiter.api.Test; + +class FilterFieldScannerTest { + @Test + void extractsBareIdentifiers() { + Set ids = FilterFieldScanner.referencedFields("body = 'foo'"); + assertEquals(Set.of("body"), ids); + } + + @Test + void ignoresStringLiterals() { + Set ids = FilterFieldScanner.referencedFields("title = 'body'"); + assertEquals(Set.of("title"), ids); + } + + @Test + void ignoresDoubleQuotedLiterals() { + Set ids = FilterFieldScanner.referencedFields("title = \"body\""); + assertEquals(Set.of("title"), ids); + } + + @Test + void distinguishesIdentifierFromSubstring() { + Set ids = FilterFieldScanner.referencedFields("bodyguard = 'x'"); + assertEquals(Set.of("bodyguard"), ids); + } + + @Test + void compoundExpressions() { + Set ids = 
FilterFieldScanner.referencedFields("body = 'foo' AND title != 'bar' OR rank > 10"); + assertEquals(Set.of("body", "title", "rank", "AND", "OR"), ids); + } + + @Test + void emptyFilter() { + assertTrue(FilterFieldScanner.referencedFields("").isEmpty()); + assertTrue(FilterFieldScanner.referencedFields(null).isEmpty()); + } + + @Test + void escapedQuoteDoesNotEndLiteral() { + Set ids = FilterFieldScanner.referencedFields("title = 'O\\'Brien' AND body LIKE 'x'"); + assertEquals(Set.of("title", "AND", "body", "LIKE"), ids); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SidecarJsonTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SidecarJsonTest.java new file mode 100644 index 000000000..d94717b80 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SidecarJsonTest.java @@ -0,0 +1,69 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.time.Instant; +import java.util.Map; +import org.junit.jupiter.api.Test; + +class SidecarJsonTest { + @Test + void roundTripFullMetadata() { + EncryptionSpec spec = new EncryptionSpec( + "AES-256-GCM", "body-key-v1", + Instant.parse("2026-04-28T03:14:15Z"), + Instant.parse("2026-04-28T05:00:00Z")); + EncryptionMetadata original = new EncryptionMetadata(1, "docs", Map.of("body", spec)); + + String json = SidecarJson.write(original); + EncryptionMetadata back = SidecarJson.read(json); + + assertEquals(original, back); + } + + @Test + void roundTripWithoutRotatedAt() { + EncryptionSpec spec = new EncryptionSpec( + "AES-256-GCM", "k1", Instant.parse("2026-04-28T00:00:00Z"), null); + EncryptionMetadata original = new EncryptionMetadata(1, "c", Map.of("f", spec)); + + String json = SidecarJson.write(original); + EncryptionMetadata back = SidecarJson.read(json); + assertEquals(original, back); + } + + @Test + void roundTripEmptyFields() { + 
EncryptionMetadata original = EncryptionMetadata.empty("c"); + EncryptionMetadata back = SidecarJson.read(SidecarJson.write(original)); + assertEquals(original, back); + } + + @Test + void rejectsMalformed() { + assertThrows(EncryptionMetadataIOException.class, () -> SidecarJson.read("not json")); + assertThrows(EncryptionMetadataIOException.class, () -> SidecarJson.read("{")); + assertThrows(EncryptionMetadataIOException.class, () -> SidecarJson.read("{\"version\":\"1\"}")); + } + + @Test + void rejectsUnknownVersion() { + String s = "{\"version\":2,\"collection_name\":\"c\",\"fields\":{}}"; + assertThrows(IllegalArgumentException.class, () -> SidecarJson.read(s)); + } + + @Test + void rejectsTrailingGarbage() { + String good = "{\"version\":1,\"collection_name\":\"c\",\"fields\":{}}"; + String bad = good + "extra"; + assertThrows(EncryptionMetadataIOException.class, () -> SidecarJson.read(bad)); + } + + @Test + void allowsTrailingWhitespace() { + String s = "{\"version\":1,\"collection_name\":\"c\",\"fields\":{}}\n \n"; + EncryptionMetadata back = SidecarJson.read(s); + org.junit.jupiter.api.Assertions.assertEquals("c", back.collectionName()); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SidecarMetadataTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SidecarMetadataTest.java new file mode 100644 index 000000000..591d1563a --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SidecarMetadataTest.java @@ -0,0 +1,56 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Instant; +import java.util.Map; +import java.util.Optional; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +class 
SidecarMetadataTest { + @Test + void readReturnsEmptyWhenSidecarAbsent(@TempDir Path dir) { + Optional result = SidecarMetadata.read(dir); + assertTrue(result.isEmpty()); + } + + @Test + void writeThenRead(@TempDir Path dir) { + EncryptionSpec spec = new EncryptionSpec("AES-256-GCM", "k1", Instant.parse("2026-04-28T00:00:00Z"), null); + EncryptionMetadata meta = new EncryptionMetadata(1, "docs", Map.of("body", spec)); + SidecarMetadata.write(dir, meta); + + Optional back = SidecarMetadata.read(dir); + assertTrue(back.isPresent()); + assertEquals(meta, back.get()); + assertTrue(Files.exists(dir.resolve("_zvec_enc.json"))); + } + + @Test + void atomicReplaceLeavesNoTempFile(@TempDir Path dir) throws IOException { + EncryptionSpec spec = new EncryptionSpec("AES-256-GCM", "k1", Instant.now(), null); + EncryptionMetadata first = new EncryptionMetadata(1, "docs", Map.of("body", spec)); + SidecarMetadata.write(dir, first); + EncryptionSpec spec2 = new EncryptionSpec("AES-256-GCM", "k2", Instant.now(), Instant.now()); + EncryptionMetadata second = new EncryptionMetadata(1, "docs", Map.of("body", spec2)); + SidecarMetadata.write(dir, second); + + assertTrue(Files.exists(dir.resolve("_zvec_enc.json"))); + try (var stream = Files.list(dir)) { + assertEquals(0L, stream.filter(p -> p.getFileName().toString().endsWith(".tmp")).count()); + } + assertEquals(second, SidecarMetadata.read(dir).orElseThrow()); + } + + @Test + void corruptedFileSurfacesAsIOException(@TempDir Path dir) throws IOException { + Files.writeString(dir.resolve("_zvec_enc.json"), "garbage}not}json"); + assertThrows(EncryptionMetadataIOException.class, () -> SidecarMetadata.read(dir)); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SingletonKeyProviderTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SingletonKeyProviderTest.java new file mode 100644 index 000000000..298a4bfb8 --- /dev/null +++ 
b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/crypto/SingletonKeyProviderTest.java @@ -0,0 +1,49 @@ +package org.zvec.crypto; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +class SingletonKeyProviderTest { + @Test + void resolvesMatchingKeyId() { + byte[] key = new byte[32]; + KeyProvider p = new SingletonKeyProvider("k1", key); + assertArrayEquals(key, p.resolve("k1")); + } + + @Test + void returnsNullForOtherKeyId() { + KeyProvider p = new SingletonKeyProvider("k1", new byte[32]); + assertNull(p.resolve("k2")); + } + + @Test + void rejectsNonAes256KeyLength() { + assertThrows(IllegalArgumentException.class, () -> new SingletonKeyProvider("k1", new byte[31])); + assertThrows(IllegalArgumentException.class, () -> new SingletonKeyProvider("k1", new byte[33])); + } + + @Test + void rejectsEmptyKeyId() { + assertThrows(IllegalArgumentException.class, () -> new SingletonKeyProvider("", new byte[32])); + } + + @Test + void defensiveCopy() { + byte[] original = new byte[32]; + original[0] = 1; + SingletonKeyProvider p = new SingletonKeyProvider("k1", original); + original[0] = 99; + org.junit.jupiter.api.Assertions.assertEquals(1, p.resolve("k1")[0]); + } + + @Test + void isActiveDefaultsTrue() { + KeyProvider p = new SingletonKeyProvider("k1", new byte[32]); + assertTrue(p.isActive("k1")); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/internal/NativeBackendsTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/internal/NativeBackendsTest.java new file mode 100644 index 000000000..61ed8d0ac --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/internal/NativeBackendsTest.java @@ -0,0 +1,117 @@ +package org.zvec.internal; + +import static 
org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; +import org.zvec.CollectionSchema; +import org.zvec.Doc; +import org.zvec.VectorQuery; +import org.junit.jupiter.api.Test; + +class NativeBackendsTest { + @Test + void rejectsMissingProviders() { + IllegalStateException ex = + assertThrows(IllegalStateException.class, () -> NativeBackends.resolve(List.of(), null)); + + assertEquals( + "No zvec native backend found. Add exactly one backend dependency: zvec-java-jni or zvec-java-ffm.", + ex.getMessage()); + } + + @Test + void acceptsSingleProvider() { + NativeBackend backend = backend("jni"); + + assertEquals(backend, NativeBackends.resolve(List.of(provider(backend)), null)); + } + + @Test + void rejectsMultipleProvidersWithoutSelection() { + IllegalStateException ex = + assertThrows( + IllegalStateException.class, + () -> NativeBackends.resolve(List.of(provider(backend("jni")), provider(backend("ffm"))), null)); + + assertEquals( + "Multiple zvec native backends found: jni, ffm. Set -Dorg.zvec.backend=jni or -Dorg.zvec.backend=ffm.", + ex.getMessage()); + } + + @Test + void selectsExplicitProvider() { + NativeBackend jni = backend("jni"); + NativeBackend ffm = backend("ffm"); + + assertEquals(ffm, NativeBackends.resolve(List.of(provider(jni), provider(ffm)), "ffm")); + } + + @Test + void rejectsUnknownExplicitProvider() { + IllegalStateException ex = + assertThrows( + IllegalStateException.class, + () -> NativeBackends.resolve(List.of(provider(backend("jni"))), "ffm")); + + assertEquals( + "Requested zvec native backend 'ffm' was not found. 
Available backends: jni.", + ex.getMessage()); + } + + private static NativeBackendProvider provider(NativeBackend backend) { + return () -> backend; + } + + private static NativeBackend backend(String id) { + return new NativeBackend() { + @Override + public String id() { + return id; + } + + @Override + public String version() { + return "test"; + } + + @Override + public void ensureInitialized() {} + + @Override + public NativeOpenResult open(String path) { + throw new UnsupportedOperationException(); + } + + @Override + public NativeOpenResult createAndOpen(String path, CollectionSchema schema) { + throw new UnsupportedOperationException(); + } + + @Override + public void close(NativeHandle handle) {} + + @Override + public void flush(NativeHandle handle) {} + + @Override + public CollectionSchema readSchema(NativeHandle handle) { + throw new UnsupportedOperationException(); + } + + @Override + public int insert(NativeHandle handle, CollectionSchema schema, java.util.List docs) { + throw new UnsupportedOperationException(); + } + + @Override + public java.util.List query( + NativeHandle handle, + CollectionSchema querySchema, + CollectionSchema resultSchema, + VectorQuery query) { + throw new UnsupportedOperationException(); + } + }; + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/AbstractCollectionConcurrentStressMainTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/AbstractCollectionConcurrentStressMainTest.java new file mode 100644 index 000000000..05b43af40 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/AbstractCollectionConcurrentStressMainTest.java @@ -0,0 +1,66 @@ +package org.zvec.perf; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import 
org.junit.jupiter.api.io.TempDir; + +public abstract class AbstractCollectionConcurrentStressMainTest { + @TempDir Path tempDir; + + @Test + void printsConcurrentStressSummaryForSmallRun() throws Exception { + assumeSupportedPlatform(); + + PrintStream originalOut = System.out; + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + + try { + System.setOut(new PrintStream(buffer, true, StandardCharsets.UTF_8)); + CollectionConcurrentStressMain.main( + new String[] { + "--docs", + "200", + "--concurrent-query-threads", + "2", + "--concurrent-query-count", + "10", + "--concurrent-mixed-threads", + "2", + "--concurrent-mixed-rounds", + "2", + "--concurrent-mixed-insert-batch-size", + "5", + "--concurrent-mixed-queries-per-round", + "8", + "--work-dir", + tempDir.resolve("concurrent").toString() + }); + } finally { + System.setOut(originalOut); + } + + String output = buffer.toString(StandardCharsets.UTF_8); + assertTrue(output.contains("CONCURRENT_STRESS_CONFIG docs=200")); + assertTrue(output.contains("QUERY_ONLY_SUMMARY")); + assertTrue(output.contains("queries_per_sec=")); + assertTrue(output.contains("MIXED_SUMMARY")); + assertTrue(output.contains("insert_docs_per_sec=")); + assertTrue(output.contains("query_failures=")); + assertTrue(output.contains("insert_failures=")); + assertTrue(output.contains("ARTIFACT_DIR " + tempDir.resolve("concurrent"))); + assertTrue(output.contains("exec:java@run-concurrent-stress")); + } + + private static void assumeSupportedPlatform() { + String osName = System.getProperty("os.name", "").toLowerCase(); + String osArch = System.getProperty("os.arch", "").toLowerCase(); + Assumptions.assumeTrue(osName.contains("mac")); + Assumptions.assumeTrue(osArch.equals("aarch64") || osArch.equals("arm64")); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/AbstractCollectionStressMainTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/AbstractCollectionStressMainTest.java new file mode 100644 
index 000000000..611afd31c --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/AbstractCollectionStressMainTest.java @@ -0,0 +1,69 @@ +package org.zvec.perf; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public abstract class AbstractCollectionStressMainTest { + @TempDir Path tempDir; + + @Test + void printsStressSummaryForSmallRun() throws Exception { + assumeSupportedPlatform(); + + PrintStream originalOut = System.out; + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + + try { + System.setOut(new PrintStream(buffer, true, StandardCharsets.UTF_8)); + CollectionStressMain.main( + new String[] { + "--docs", + "200", + "--queries", + "20", + "--batch-size", + "50", + "--warmup-queries", + "5", + "--steady-state-rounds", + "2", + "--steady-insert-batch-size", + "10", + "--steady-queries-per-round", + "5", + "--work-dir", + tempDir.resolve("stress").toString() + }); + } finally { + System.setOut(originalOut); + } + + String output = buffer.toString(StandardCharsets.UTF_8); + assertTrue(output.contains("STRESS_CONFIG docs=200")); + assertTrue(output.contains("INSERT_SUMMARY docs=200")); + assertTrue(output.contains("QUERY_SUMMARY count=20")); + assertTrue(output.contains("miss_count=")); + assertTrue(output.contains("recall=")); + assertTrue(output.contains("MEMORY_SUMMARY")); + assertTrue(output.contains("heap_before_mb=")); + assertTrue(output.contains("rss_before_mb=")); + assertTrue(output.contains("STEADY_STATE rounds=2")); + assertTrue(output.contains("ARTIFACT_DIR " + tempDir.resolve("stress"))); + assertTrue(output.contains("exec:java@run-stress")); + assertTrue(output.contains("-Dzvec.stress.args='--docs 100000'")); + } + + private static void 
assumeSupportedPlatform() { + String osName = System.getProperty("os.name", "").toLowerCase(); + String osArch = System.getProperty("os.arch", "").toLowerCase(); + Assumptions.assumeTrue(osName.contains("mac")); + Assumptions.assumeTrue(osArch.equals("aarch64") || osArch.equals("arm64")); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/CollectionStressChecksTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/CollectionStressChecksTest.java new file mode 100644 index 000000000..ed9a0829a --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/CollectionStressChecksTest.java @@ -0,0 +1,28 @@ +package org.zvec.perf; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import org.junit.jupiter.api.Test; +import org.zvec.Doc; + +class CollectionStressChecksTest { + @Test + void matchesExpectedHitWhenExpectedDocumentIsPresent() { + assertTrue( + CollectionStressMain.hasExpectedHit( + "doc_7", List.of(Doc.result("doc_7", 1.0), Doc.result("doc_9", 0.8)))); + assertTrue( + CollectionStressMain.hasExpectedHit( + "doc_7", List.of(Doc.result("doc_9", 1.0), Doc.result("doc_7", 0.8)))); + } + + @Test + void rejectsEmptyResultsAndMissingExpectedHit() { + assertFalse(CollectionStressMain.hasExpectedHit("doc_7", List.of())); + assertFalse( + CollectionStressMain.hasExpectedHit( + "doc_7", List.of(Doc.result("doc_9", 1.0), Doc.result("doc_8", 0.8)))); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/EncryptedFieldBenchmark.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/EncryptedFieldBenchmark.java new file mode 100644 index 000000000..05184ca8d --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/EncryptedFieldBenchmark.java @@ -0,0 +1,53 @@ +package org.zvec.perf; + +import java.security.SecureRandom; +import java.util.concurrent.TimeUnit; +import 
org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.zvec.crypto.AesGcm256; +import org.zvec.crypto.AadEncoder; + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class EncryptedFieldBenchmark { + + @Param({"64", "1024", "1048576"}) + public int size; + + private byte[] key; + private byte[] nonce; + private byte[] plaintext; + private byte[] aad; + private byte[] ciphertext; + private AesGcm256 aead; + + @Setup + public void setup() { + aead = new AesGcm256(); + key = new byte[32]; + new SecureRandom().nextBytes(key); + nonce = new byte[12]; + new SecureRandom().nextBytes(nonce); + plaintext = new byte[size]; + new SecureRandom().nextBytes(plaintext); + aad = AadEncoder.encode("d1", "body", "docs"); + ciphertext = aead.seal(key, nonce, plaintext, aad); + } + + @Benchmark + public byte[] encrypt() { + return aead.seal(key, nonce, plaintext, aad); + } + + @Benchmark + public byte[] decrypt() { + return aead.open(key, nonce, ciphertext, aad); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/LatencyStatsTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/LatencyStatsTest.java new file mode 100644 index 000000000..53875f099 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/LatencyStatsTest.java @@ -0,0 +1,30 @@ +package org.zvec.perf; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +class LatencyStatsTest { + @Test + void computesPercentilesFromSamples() { + LatencyStats stats = + LatencyStats.fromNanos( + new long[] 
{ + 1_000_000L, 2_000_000L, 3_000_000L, 4_000_000L, 5_000_000L + }); + + assertEquals(5, stats.count()); + assertEquals(1_000.0, stats.minMicros()); + assertEquals(3_000.0, stats.p50Micros()); + assertEquals(5_000.0, stats.p95Micros()); + assertEquals(5_000.0, stats.p99Micros()); + assertEquals(5_000.0, stats.maxMicros()); + assertEquals(3_000.0, stats.meanMicros()); + } + + @Test + void rejectsEmptySamples() { + assertThrows(IllegalArgumentException.class, () -> LatencyStats.fromNanos(new long[0])); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/StressOptionsTest.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/StressOptionsTest.java new file mode 100644 index 000000000..cc229d851 --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/StressOptionsTest.java @@ -0,0 +1,102 @@ +package org.zvec.perf; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.file.Path; +import org.junit.jupiter.api.Test; +import org.zvec.HnswIndexParams; +import org.zvec.HnswQueryParams; + +class StressOptionsTest { + @Test + void parsesExplicitArguments() { + StressOptions options = + StressOptions.parse( + new String[] { + "--docs", "1000000", + "--queries", "2500", + "--batch-size", "2000", + "--dimension", "256", + "--top-k", "20", + "--warmup-queries", "200", + "--steady-state-rounds", "12", + "--steady-insert-batch-size", "500", + "--steady-queries-per-round", "40", + "--concurrent-query-threads", "3", + "--concurrent-query-count", "24", + "--concurrent-mixed-threads", "4", + "--concurrent-mixed-rounds", "6", + "--concurrent-mixed-insert-batch-size", "7", + "--concurrent-mixed-queries-per-round", "8", + "--hnsw-m", "32", + "--hnsw-ef-construction", "300", + "--hnsw-ef", "128", + "--seed", "99", + "--work-dir", "target/perf/custom" + }); + + assertEquals(1_000_000, 
options.docCount()); + assertEquals(2_500, options.queryCount()); + assertEquals(2_000, options.batchSize()); + assertEquals(256, options.dimension()); + assertEquals(20, options.topK()); + assertEquals(200, options.warmupQueries()); + assertEquals(12, options.steadyStateRounds()); + assertEquals(500, options.steadyInsertBatchSize()); + assertEquals(40, options.steadyQueriesPerRound()); + assertEquals(3, options.concurrentQueryThreads()); + assertEquals(24, options.concurrentQueryCount()); + assertEquals(4, options.concurrentMixedThreads()); + assertEquals(6, options.concurrentMixedRounds()); + assertEquals(7, options.concurrentMixedInsertBatchSize()); + assertEquals(8, options.concurrentMixedQueriesPerRound()); + assertEquals(new HnswIndexParams(32, 300), options.hnswIndexParams()); + assertEquals(new HnswQueryParams(128, 0.0f, false, false), options.hnswQueryParams()); + assertEquals(99L, options.seed()); + assertEquals(Path.of("target/perf/custom"), options.workDir()); + } + + @Test + void usesDefaultsWhenArgumentsAreOmitted() { + StressOptions options = StressOptions.parse(new String[0]); + + assertEquals(100_000, options.docCount()); + assertEquals(1_000, options.queryCount()); + assertEquals(1_000, options.batchSize()); + assertEquals(128, options.dimension()); + assertEquals(10, options.topK()); + assertEquals(100, options.warmupQueries()); + assertEquals(20, options.steadyStateRounds()); + assertEquals(100, options.steadyInsertBatchSize()); + assertEquals(20, options.steadyQueriesPerRound()); + assertEquals(2, options.concurrentQueryThreads()); + assertEquals(20, options.concurrentQueryCount()); + assertEquals(2, options.concurrentMixedThreads()); + assertEquals(2, options.concurrentMixedRounds()); + assertEquals(5, options.concurrentMixedInsertBatchSize()); + assertEquals(5, options.concurrentMixedQueriesPerRound()); + assertNull(options.hnswIndexParams()); + assertNull(options.hnswQueryParams()); + assertEquals(7L, options.seed()); + 
assertEquals(Path.of("target/perf/zvec-stress"), options.workDir()); + } + + @Test + void rejectsUnknownArguments() { + assertThrows( + IllegalArgumentException.class, + () -> StressOptions.parse(new String[] {"--nope", "1"})); + } + + @Test + void rejectsPartialHnswIndexArguments() { + assertThrows( + IllegalArgumentException.class, + () -> StressOptions.parse(new String[] {"--hnsw-m", "32"})); + assertThrows( + IllegalArgumentException.class, + () -> StressOptions.parse(new String[] {"--hnsw-ef-construction", "300"})); + } +} diff --git a/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/ZvecJavaBindingBenchmark.java b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/ZvecJavaBindingBenchmark.java new file mode 100644 index 000000000..4be23ff4e --- /dev/null +++ b/java/zvec-java/zvec-java-api/src/test/java/org/zvec/perf/ZvecJavaBindingBenchmark.java @@ -0,0 +1,218 @@ +package org.zvec.perf; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; +import org.junit.jupiter.api.Assumptions; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; 
+import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.TimeValue; +import org.zvec.Collection; +import org.zvec.CollectionSchema; +import org.zvec.Doc; +import org.zvec.VectorQuery; +import org.zvec.Zvec; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@Fork(value = 1, jvmArgsAppend = {"--enable-native-access=ALL-UNNAMED"}) +@Warmup(iterations = 1, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 2, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Threads(1) +public class ZvecJavaBindingBenchmark { + private static final int DATASET_DOC_COUNT = 10_000; + private static final int LOAD_BATCH_SIZE = 1_000; + private static final int DIMENSION = 128; + private static final int TOP_K = 10; + private static final long SEED = 7L; + + public static void main(String[] args) throws RunnerException { + Options options = + new OptionsBuilder() + .include(ZvecJavaBindingBenchmark.class.getName()) + .forks(1) + .warmupIterations(1) + .warmupTime(TimeValue.milliseconds(500)) + .measurementIterations(2) + .measurementTime(TimeValue.milliseconds(500)) + .shouldDoGC(true) + .build(); + new Runner(options).run(); + } + + @Benchmark + public void queryProjectedScalarFields(QueryState state, Blackhole blackhole) { + List results = + state.collection.query( + VectorQuery.of("embedding", state.nextQueryVector()).topK(TOP_K).outputFields("title")); + blackhole.consume(results.size()); + if (!results.isEmpty()) { + Doc first = results.get(0); + blackhole.consume(first.id()); + blackhole.consume(first.fields().get("title")); + } + } + + @Benchmark + public void queryWithVectors(QueryState state, Blackhole blackhole) { + List results = + state.collection.query( + VectorQuery.of("embedding", state.nextQueryVector()) + .topK(TOP_K) + .outputFields("title") + .includeVector(true)); + blackhole.consume(results.size()); + if (!results.isEmpty()) { + Doc first = results.get(0); + 
blackhole.consume(first.id()); + blackhole.consume(first.fields().get("title")); + blackhole.consume(first.vectors().get("embedding")); + } + } + + @State(Scope.Benchmark) + public static class QueryState extends CollectionState { + private final AtomicInteger nextQueryVectorIndex = new AtomicInteger(); + private float[][] queryVectors; + + @Override + protected void afterSetUp() { + queryVectors = new float[256][]; + for (int i = 0; i < queryVectors.length; i++) { + queryVectors[i] = PerfData.querySample(i, DIMENSION, SEED, TOP_K).vector(); + } + } + + float[] nextQueryVector() { + int index = Math.floorMod(nextQueryVectorIndex.getAndIncrement(), queryVectors.length); + return queryVectors[index]; + } + } + + private abstract static class CollectionState { + protected Collection collection; + protected Path workDir; + + @Setup(Level.Trial) + public final void setUp() throws IOException { + assumeSupportedPlatform(); + workDir = Files.createTempDirectory("zvec-jmh-"); + CollectionSchema schema = PerfData.schema("perf_docs", DIMENSION); + collection = Zvec.createAndOpen(workDir.resolve("collection").toString(), schema); + + try { + for (int startDocIndex = 0; startDocIndex < DATASET_DOC_COUNT; startDocIndex += LOAD_BATCH_SIZE) { + int batchSize = Math.min(LOAD_BATCH_SIZE, DATASET_DOC_COUNT - startDocIndex); + int inserted = collection.insert(PerfData.docs(startDocIndex, batchSize, DIMENSION, SEED)); + if (inserted != batchSize) { + throw new IllegalStateException( + "Inserted count mismatch: expected " + batchSize + ", got " + inserted); + } + } + collection.flush(); + afterSetUp(); + } catch (RuntimeException e) { + tearDownQuietly(); + throw e; + } + } + + @TearDown(Level.Trial) + public final void tearDown() throws IOException { + IOException failure = null; + if (collection != null) { + try { + collection.close(); + } catch (RuntimeException e) { + failure = new IOException("Failed to close benchmark collection", e); + } finally { + collection = null; + } + } + 
+ if (workDir != null) { + try { + deleteRecursively(workDir); + } catch (IOException e) { + if (failure == null) { + failure = e; + } else { + failure.addSuppressed(e); + } + } finally { + workDir = null; + } + } + + if (failure != null) { + throw failure; + } + } + + private void tearDownQuietly() { + try { + tearDown(); + } catch (IOException ignored) { + } + } + + protected void afterSetUp() {} + } + + private static void assumeSupportedPlatform() { + String osName = System.getProperty("os.name", "").toLowerCase(); + String osArch = System.getProperty("os.arch", "").toLowerCase(); + Assumptions.assumeTrue(osName.contains("mac")); + Assumptions.assumeTrue(osArch.equals("aarch64") || osArch.equals("arm64")); + } + + private static void deleteRecursively(Path root) throws IOException { + if (root == null || Files.notExists(root)) { + return; + } + + try (Stream paths = Files.walk(root)) { + paths.sorted(Comparator.reverseOrder()) + .forEach( + path -> { + try { + Files.deleteIfExists(path); + } catch (IOException e) { + throw new RecursiveDeleteException(e); + } + }); + } catch (RecursiveDeleteException e) { + throw e.cause; + } + } + + private static final class RecursiveDeleteException extends RuntimeException { + private final IOException cause; + + private RecursiveDeleteException(IOException cause) { + super(cause); + this.cause = cause; + } + } +} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionInsertIntegrationTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionInsertIntegrationTest.java new file mode 100644 index 000000000..c27086c09 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionInsertIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class FfmCollectionInsertIntegrationTest extends AbstractCollectionInsertIntegrationTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionLifecycleIntegrationTest.java 
b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionLifecycleIntegrationTest.java new file mode 100644 index 000000000..0958ade56 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionLifecycleIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class FfmCollectionLifecycleIntegrationTest extends AbstractCollectionLifecycleIntegrationTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionQueryIntegrationTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionQueryIntegrationTest.java new file mode 100644 index 000000000..6874841e4 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmCollectionQueryIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class FfmCollectionQueryIntegrationTest extends AbstractCollectionQueryIntegrationTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmConcurrentEncryptedInsertIntegrationTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmConcurrentEncryptedInsertIntegrationTest.java new file mode 100644 index 000000000..3e6467477 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmConcurrentEncryptedInsertIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class FfmConcurrentEncryptedInsertIntegrationTest extends AbstractConcurrentEncryptedInsertIntegrationTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmEncryptedCollectionRoundTripIntegrationTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmEncryptedCollectionRoundTripIntegrationTest.java new file mode 100644 index 000000000..c34d61840 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmEncryptedCollectionRoundTripIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class FfmEncryptedCollectionRoundTripIntegrationTest extends AbstractEncryptedCollectionRoundTripIntegrationTest {} diff --git 
a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmEncryptionMetadataMismatchIntegrationTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmEncryptionMetadataMismatchIntegrationTest.java new file mode 100644 index 000000000..57d1adde0 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmEncryptionMetadataMismatchIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class FfmEncryptionMetadataMismatchIntegrationTest extends AbstractEncryptionMetadataMismatchIntegrationTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmKeyRotationIntegrationTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmKeyRotationIntegrationTest.java new file mode 100644 index 000000000..854b2bee8 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmKeyRotationIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class FfmKeyRotationIntegrationTest extends AbstractKeyRotationIntegrationTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmLargeEncryptedPayloadIntegrationTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmLargeEncryptedPayloadIntegrationTest.java new file mode 100644 index 000000000..7c0f2059c --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmLargeEncryptedPayloadIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class FfmLargeEncryptedPayloadIntegrationTest extends AbstractLargeEncryptedPayloadIntegrationTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmOpenWithoutKeysIntegrationTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmOpenWithoutKeysIntegrationTest.java new file mode 100644 index 000000000..dbd6cee22 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmOpenWithoutKeysIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class FfmOpenWithoutKeysIntegrationTest extends 
AbstractOpenWithoutKeysIntegrationTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmQuickStartTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmQuickStartTest.java new file mode 100644 index 000000000..8abdd690b --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/FfmQuickStartTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class FfmQuickStartTest extends AbstractQuickStartTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmCollectionSetActiveKeyIdTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmCollectionSetActiveKeyIdTest.java new file mode 100644 index 000000000..ce32e1351 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmCollectionSetActiveKeyIdTest.java @@ -0,0 +1,3 @@ +package org.zvec.crypto; + +final class FfmCollectionSetActiveKeyIdTest extends AbstractCollectionSetActiveKeyIdTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmZvecCreateAndOpenWithProviderTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmZvecCreateAndOpenWithProviderTest.java new file mode 100644 index 000000000..c04a407b0 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmZvecCreateAndOpenWithProviderTest.java @@ -0,0 +1,3 @@ +package org.zvec.crypto; + +final class FfmZvecCreateAndOpenWithProviderTest extends AbstractZvecCreateAndOpenWithProviderTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmZvecOpenWithKeysTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmZvecOpenWithKeysTest.java new file mode 100644 index 000000000..b6d7ed9ab --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/crypto/FfmZvecOpenWithKeysTest.java @@ -0,0 +1,3 @@ +package org.zvec.crypto; + +final class FfmZvecOpenWithKeysTest extends AbstractZvecOpenWithKeysTest {} diff --git 
a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmNativeLoaderTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmNativeLoaderTest.java new file mode 100644 index 000000000..ca2ade815 --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmNativeLoaderTest.java @@ -0,0 +1,57 @@ +package org.zvec.internal.ffm; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.file.Path; +import org.junit.jupiter.api.Test; + +class FfmNativeLoaderTest { + @Test + void mapsMacOsArm64ToBundledDylib() { + assertEquals( + "/META-INF/native/darwin-aarch64/libzvec_c_api.dylib", + FfmNativeLoader.platformResourcePath("Mac OS X", "aarch64")); + } + + @Test + void mapsLinuxAndWindowsToBundledLibraries() { + assertEquals( + "/META-INF/native/linux-x86_64/libzvec_c_api.so", + FfmNativeLoader.platformResourcePath("Linux", "amd64")); + assertEquals( + "/META-INF/native/linux-aarch64/libzvec_c_api.so", + FfmNativeLoader.platformResourcePath("Linux", "aarch64")); + assertEquals( + "/META-INF/native/windows-x86_64/zvec_c_api.dll", + FfmNativeLoader.platformResourcePath("Windows 10", "amd64")); + } + + @Test + void resolvesPlatformSpecificLibraryNames() { + assertEquals("libzvec_c_api.dylib", FfmNativeLoader.cApiLibraryName("darwin-aarch64")); + assertEquals("libzvec_c_api.so", FfmNativeLoader.cApiLibraryName("linux-x86_64")); + assertEquals("zvec_c_api.dll", FfmNativeLoader.cApiLibraryName("windows-x86_64")); + } + + @Test + void rejectsUnsupportedPlatforms() { + assertThrows( + IllegalStateException.class, + () -> FfmNativeLoader.platformResourcePath("Linux", "riscv64")); + } + + @Test + void createsUniqueExtractionTargets() { + Path first = 
FfmNativeLoader.extractionTarget("/META-INF/native/darwin-aarch64/libzvec_c_api.dylib"); + Path second = FfmNativeLoader.extractionTarget("/META-INF/native/darwin-aarch64/libzvec_c_api.dylib"); + + assertEquals("libzvec_c_api.dylib", first.getFileName().toString()); + assertEquals("libzvec_c_api.dylib", second.getFileName().toString()); + assertNotEquals(first, second); + assertNotEquals(first.getParent(), second.getParent()); + assertTrue(first.getParent().getFileName().toString().startsWith("zvec-java-darwin-aarch64-")); + } +} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmNativeVersionTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmNativeVersionTest.java new file mode 100644 index 000000000..faa72909c --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmNativeVersionTest.java @@ -0,0 +1,12 @@ +package org.zvec.internal.ffm; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +import org.junit.jupiter.api.Test; + +class FfmNativeVersionTest { + @Test + void returnsNonBlankVersionString() { + assertFalse(FfmNative.version().isBlank()); + } +} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmQueriesTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmQueriesTest.java new file mode 100644 index 000000000..147bf3b3f --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/internal/ffm/FfmQueriesTest.java @@ -0,0 +1,62 @@ +package org.zvec.internal.ffm; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.zvec.DataType; +import org.zvec.HnswQueryParams; +import org.zvec.TuningProfile; +import org.zvec.VectorQuery; +import org.zvec.VectorSchema; + +class FfmQueriesTest { + @Test + void 
attachesHnswParamsWhenRuntimeSchemaCarriesHnswIndexParams() { + VectorSchema runtimeSchema = + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withHnswIndex(new org.zvec.HnswIndexParams(16, 200)); + + assertTrue( + FfmQueries.shouldAttachHnswParams( + runtimeSchema, VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}))); + } + + @Test + void skipsHnswParamsWhenRuntimeSchemaHasNoHnswContext() { + VectorSchema runtimeSchema = new VectorSchema("embedding", DataType.VECTOR_FP32, 4); + + assertFalse( + FfmQueries.shouldAttachHnswParams( + runtimeSchema, VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}))); + } + + @Test + void skipsExplicitQueryHnswTuningWhenRuntimeSchemaHasNoHnswContext() { + VectorSchema runtimeSchema = new VectorSchema("embedding", DataType.VECTOR_FP32, 4); + VectorQuery query = + VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}) + .withTuningProfile(TuningProfile.ACCURATE) + .hnsw(new HnswQueryParams(128, 0.0f, false, false)); + + assertFalse(FfmQueries.shouldAttachHnswParams(runtimeSchema, query)); + } + + @Test + void resolvesQueryDefaultsFromPublicSchemaWhenRuntimeSchemaCarriesHnswContext() { + VectorSchema runtimeSchema = + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withHnswIndex(new org.zvec.HnswIndexParams(32, 400)); + VectorSchema publicSchema = + new VectorSchema("embedding", DataType.VECTOR_FP32, 4) + .withTuningProfile(TuningProfile.ACCURATE, 1_000_000L); + + assertEquals( + new HnswQueryParams(128, 0.0f, false, false), + FfmQueries.resolveAttachedHnswParams( + runtimeSchema, + publicSchema, + VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}))); + } +} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmCollectionConcurrentStressMainTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmCollectionConcurrentStressMainTest.java new file mode 100644 index 000000000..66c0fd1b1 --- /dev/null +++ 
b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmCollectionConcurrentStressMainTest.java @@ -0,0 +1,3 @@ +package org.zvec.perf; + +final class FfmCollectionConcurrentStressMainTest extends AbstractCollectionConcurrentStressMainTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmCollectionStressMainTest.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmCollectionStressMainTest.java new file mode 100644 index 000000000..43336bdba --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmCollectionStressMainTest.java @@ -0,0 +1,3 @@ +package org.zvec.perf; + +final class FfmCollectionStressMainTest extends AbstractCollectionStressMainTest {} diff --git a/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmDocsBenchmark.java b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmDocsBenchmark.java new file mode 100644 index 000000000..47170df2b --- /dev/null +++ b/java/zvec-java/zvec-java-ffm/src/test/java/org/zvec/perf/FfmDocsBenchmark.java @@ -0,0 +1,83 @@ +package org.zvec.perf; + +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Assumptions; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import 
org.openjdk.jmh.runner.options.TimeValue; +import org.zvec.CollectionSchema; +import org.zvec.Doc; +import org.zvec.internal.ffm.FfmDocs; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@Fork(value = 1, jvmArgsAppend = {"--enable-native-access=ALL-UNNAMED"}) +@Warmup(iterations = 1, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 2, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Threads(1) +public class FfmDocsBenchmark { + private static final int DATASET_DOC_COUNT = 10_000; + private static final int SMALL_BATCH_SIZE = 16; + private static final int DIMENSION = 128; + private static final long SEED = 7L; + + public static void main(String[] args) throws RunnerException { + Options options = + new OptionsBuilder() + .include(FfmDocsBenchmark.class.getName()) + .forks(1) + .warmupIterations(1) + .warmupTime(TimeValue.milliseconds(500)) + .measurementIterations(2) + .measurementTime(TimeValue.milliseconds(500)) + .shouldDoGC(true) + .build(); + new Runner(options).run(); + } + + @Benchmark + public void marshalInsertSmallBatch(MarshalState state, Blackhole blackhole) { + var nativeDocs = FfmDocs.toFfmDocs(state.docs, state.schema); + try { + blackhole.consume(nativeDocs.size()); + } finally { + FfmDocs.destroyAll(nativeDocs); + } + } + + @State(Scope.Benchmark) + public static class MarshalState { + CollectionSchema schema; + List<Doc> docs; + + @Setup(Level.Trial) + public void setUp() { + assumeSupportedPlatform(); + schema = PerfData.schema("perf_docs", DIMENSION); + docs = PerfData.docs(DATASET_DOC_COUNT, SMALL_BATCH_SIZE, DIMENSION, SEED); + } + } + + private static void assumeSupportedPlatform() { + String osName = System.getProperty("os.name", "").toLowerCase(); + String osArch = System.getProperty("os.arch", "").toLowerCase(); + Assumptions.assumeTrue(osName.contains("mac")); + Assumptions.assumeTrue(osArch.equals("aarch64") || osArch.equals("arm64")); + } +} diff --git
a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionInsertIntegrationTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionInsertIntegrationTest.java new file mode 100644 index 000000000..ab6e0ecde --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionInsertIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class JniCollectionInsertIntegrationTest extends AbstractCollectionInsertIntegrationTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionLifecycleIntegrationTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionLifecycleIntegrationTest.java new file mode 100644 index 000000000..671dae618 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionLifecycleIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class JniCollectionLifecycleIntegrationTest extends AbstractCollectionLifecycleIntegrationTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionQueryIntegrationTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionQueryIntegrationTest.java new file mode 100644 index 000000000..df4e34d98 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniCollectionQueryIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class JniCollectionQueryIntegrationTest extends AbstractCollectionQueryIntegrationTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniConcurrentEncryptedInsertIntegrationTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniConcurrentEncryptedInsertIntegrationTest.java new file mode 100644 index 000000000..661337ba5 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniConcurrentEncryptedInsertIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class JniConcurrentEncryptedInsertIntegrationTest extends 
AbstractConcurrentEncryptedInsertIntegrationTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniEncryptedCollectionRoundTripIntegrationTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniEncryptedCollectionRoundTripIntegrationTest.java new file mode 100644 index 000000000..ee66118d0 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniEncryptedCollectionRoundTripIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class JniEncryptedCollectionRoundTripIntegrationTest extends AbstractEncryptedCollectionRoundTripIntegrationTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniEncryptionMetadataMismatchIntegrationTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniEncryptionMetadataMismatchIntegrationTest.java new file mode 100644 index 000000000..f459a5ad8 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniEncryptionMetadataMismatchIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class JniEncryptionMetadataMismatchIntegrationTest extends AbstractEncryptionMetadataMismatchIntegrationTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniKeyRotationIntegrationTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniKeyRotationIntegrationTest.java new file mode 100644 index 000000000..dbe952175 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniKeyRotationIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class JniKeyRotationIntegrationTest extends AbstractKeyRotationIntegrationTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniLargeEncryptedPayloadIntegrationTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniLargeEncryptedPayloadIntegrationTest.java new file mode 100644 index 000000000..2324eeb1c --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniLargeEncryptedPayloadIntegrationTest.java @@ 
-0,0 +1,3 @@ +package org.zvec; + +final class JniLargeEncryptedPayloadIntegrationTest extends AbstractLargeEncryptedPayloadIntegrationTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniOpenWithoutKeysIntegrationTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniOpenWithoutKeysIntegrationTest.java new file mode 100644 index 000000000..aedaa7c24 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniOpenWithoutKeysIntegrationTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class JniOpenWithoutKeysIntegrationTest extends AbstractOpenWithoutKeysIntegrationTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniQuickStartTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniQuickStartTest.java new file mode 100644 index 000000000..cd49e0261 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/JniQuickStartTest.java @@ -0,0 +1,3 @@ +package org.zvec; + +final class JniQuickStartTest extends AbstractQuickStartTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniCollectionSetActiveKeyIdTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniCollectionSetActiveKeyIdTest.java new file mode 100644 index 000000000..3784a94bb --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniCollectionSetActiveKeyIdTest.java @@ -0,0 +1,3 @@ +package org.zvec.crypto; + +final class JniCollectionSetActiveKeyIdTest extends AbstractCollectionSetActiveKeyIdTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniZvecCreateAndOpenWithProviderTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniZvecCreateAndOpenWithProviderTest.java new file mode 100644 index 000000000..ca03f61f0 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniZvecCreateAndOpenWithProviderTest.java @@ -0,0 +1,3 @@ +package org.zvec.crypto; + +final class 
JniZvecCreateAndOpenWithProviderTest extends AbstractZvecCreateAndOpenWithProviderTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniZvecOpenWithKeysTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniZvecOpenWithKeysTest.java new file mode 100644 index 000000000..de101630f --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/crypto/JniZvecOpenWithKeysTest.java @@ -0,0 +1,3 @@ +package org.zvec.crypto; + +final class JniZvecOpenWithKeysTest extends AbstractZvecOpenWithKeysTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/internal/jni/JniNativeLoaderTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/internal/jni/JniNativeLoaderTest.java new file mode 100644 index 000000000..309cd2064 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/internal/jni/JniNativeLoaderTest.java @@ -0,0 +1,59 @@ +package org.zvec.internal.jni; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.file.Path; +import org.junit.jupiter.api.Test; + +class JniNativeLoaderTest { + @Test + void resolvesDarwinAarch64ResourceDir() { + assertEquals( + "/META-INF/native/darwin-aarch64", + JniNativeLoader.platformResourceDir("Mac OS X", "aarch64")); + assertEquals( + "/META-INF/native/darwin-aarch64", + JniNativeLoader.platformResourceDir("Darwin", "arm64")); + } + + @Test + void resolvesLinuxAndWindowsResourceDirs() { + assertEquals( + "/META-INF/native/linux-x86_64", + JniNativeLoader.platformResourceDir("Linux", "amd64")); + assertEquals( + "/META-INF/native/linux-aarch64", + JniNativeLoader.platformResourceDir("Linux", "aarch64")); + assertEquals( + "/META-INF/native/windows-x86_64", + JniNativeLoader.platformResourceDir("Windows 10", "amd64")); + } + + @Test + void resolvesPlatformSpecificLibraryNames() { + 
assertEquals("libzvec_c_api.dylib", JniNativeLoader.cApiLibraryName("darwin-aarch64")); + assertEquals("libzvec_java_jni.dylib", JniNativeLoader.jniLibraryName("darwin-aarch64")); + assertEquals("libzvec_c_api.so", JniNativeLoader.cApiLibraryName("linux-x86_64")); + assertEquals("libzvec_java_jni.so", JniNativeLoader.jniLibraryName("linux-x86_64")); + assertEquals("zvec_c_api.dll", JniNativeLoader.cApiLibraryName("windows-x86_64")); + assertEquals("zvec_java_jni.dll", JniNativeLoader.jniLibraryName("windows-x86_64")); + } + + @Test + void rejectsUnsupportedPlatform() { + IllegalStateException ex = + assertThrows( + IllegalStateException.class, + () -> JniNativeLoader.platformResourceDir("Linux", "riscv64")); + assertTrue(ex.getMessage().contains("Unsupported zvec-java-jni platform")); + } + + @Test + void extractionTargetUsesPlatformScopedTempDir() { + Path target = JniNativeLoader.extractionTarget("/META-INF/native/darwin-aarch64"); + + assertTrue(target.getFileName().toString().startsWith("zvec-java-darwin-aarch64-")); + } +} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/internal/jni/JniNativeVersionTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/internal/jni/JniNativeVersionTest.java new file mode 100644 index 000000000..ec1d27105 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/internal/jni/JniNativeVersionTest.java @@ -0,0 +1,12 @@ +package org.zvec.internal.jni; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +import org.junit.jupiter.api.Test; + +class JniNativeVersionTest { + @Test + void exposesNativeVersion() { + assertFalse(JniNative.version().isBlank()); + } +} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/perf/JniCollectionConcurrentStressMainTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/perf/JniCollectionConcurrentStressMainTest.java new file mode 100644 index 000000000..23afc8b0e --- /dev/null +++ 
b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/perf/JniCollectionConcurrentStressMainTest.java @@ -0,0 +1,3 @@ +package org.zvec.perf; + +final class JniCollectionConcurrentStressMainTest extends AbstractCollectionConcurrentStressMainTest {} diff --git a/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/perf/JniCollectionStressMainTest.java b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/perf/JniCollectionStressMainTest.java new file mode 100644 index 000000000..b8bcbd4eb --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/test/java/org/zvec/perf/JniCollectionStressMainTest.java @@ -0,0 +1,3 @@ +package org.zvec.perf; + +final class JniCollectionStressMainTest extends AbstractCollectionStressMainTest {} From c1be84d2821058656b346436e458cd17c74a5efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=A3=9E?= Date: Thu, 14 May 2026 11:38:33 -0400 Subject: [PATCH 3/5] build(java): add platform-aware native packaging --- .gitignore | 10 +- CMakeLists.txt | 19 +- .../src/main/native/CMakeLists.txt | 66 ++++++ scripts/build_java_native.sh | 222 ++++++++++++++++++ thirdparty/lz4/CMakeLists.txt | 1 + thirdparty/sparsehash/CMakeLists.txt | 5 +- 6 files changed, 320 insertions(+), 3 deletions(-) create mode 100644 java/zvec-java/zvec-java-jni/src/main/native/CMakeLists.txt create mode 100755 scripts/build_java_native.sh diff --git a/.gitignore b/.gitignore index 38c769e2a..0ca493ad2 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,14 @@ dist html *.lcov.info +# Java / Maven build outputs +target/ +*.jar +*.class +*.dll +*.dylib +*.so + # Dependencies /node_modules @@ -51,4 +59,4 @@ allure-* !build_android.sh !build_ios.sh - +!scripts/build_java_native.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index a33e61e99..ad6ed76c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,7 @@ set(CC_CXX_STANDARD 17) if(MSVC) set(INTTYPES_FORMAT VC7) add_compile_options(/FS) # handle .pdb concurrency + add_compile_options(/utf-8) add_compile_options(/EHsc) 
# def c++ exception behavior add_compile_options(/Zc:preprocessor /Zc:__cplusplus) add_compile_options(/we4716) # -Werror=return-type @@ -73,6 +74,12 @@ message(STATUS "BUILD_PYTHON_BINDINGS:${BUILD_PYTHON_BINDINGS}") option(BUILD_C_BINDINGS "Build C bindings" ON) message(STATUS "BUILD_C_BINDINGS:${BUILD_C_BINDINGS}") +option(BUILD_JAVA_JNI_BINDING "Build Java JNI binding" OFF) +message(STATUS "BUILD_JAVA_JNI_BINDING:${BUILD_JAVA_JNI_BINDING}") + +option(BUILD_TESTS "Build tests" ON) +message(STATUS "BUILD_TESTS:${BUILD_TESTS}") + option(BUILD_TOOLS "Build tools" ON) message(STATUS "BUILD_TOOLS:${BUILD_TOOLS}") @@ -117,7 +124,17 @@ message(STATUS "USE_OSS_MIRROR:${USE_OSS_MIRROR}") cc_directory(thirdparty) cc_directories(src) -cc_directories(tests) + +if(BUILD_TESTS) + cc_directories(tests) +endif() + +if(BUILD_JAVA_JNI_BINDING) + if(NOT BUILD_C_BINDINGS) + message(FATAL_ERROR "BUILD_JAVA_JNI_BINDING requires BUILD_C_BINDINGS") + endif() + add_subdirectory(java/zvec-java/zvec-java-jni/src/main/native) +endif() if(BUILD_TOOLS) cc_directories(tools) diff --git a/java/zvec-java/zvec-java-jni/src/main/native/CMakeLists.txt b/java/zvec-java/zvec-java-jni/src/main/native/CMakeLists.txt new file mode 100644 index 000000000..b3139eba7 --- /dev/null +++ b/java/zvec-java/zvec-java-jni/src/main/native/CMakeLists.txt @@ -0,0 +1,66 @@ +set(ZVEC_JAVA_HOME "$ENV{JAVA_HOME}" CACHE PATH "JDK home used for JNI headers") +if(NOT ZVEC_JAVA_HOME) + message(FATAL_ERROR "JAVA_HOME must be set to build zvec_java_jni") +endif() + +set(ZVEC_JAVA_JNI_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}" CACHE PATH "zvec Java JNI source directory") + +if(WIN32) + set(ZVEC_JNI_PLATFORM_INCLUDE_DIR win32) +elseif(APPLE) + set(ZVEC_JNI_PLATFORM_INCLUDE_DIR darwin) +elseif(UNIX) + set(ZVEC_JNI_PLATFORM_INCLUDE_DIR linux) +else() + message(FATAL_ERROR "Unsupported JNI platform: ${CMAKE_SYSTEM_NAME}") +endif() + +set(ZVEC_JNI_INCLUDE_DIRS + "${ZVEC_JAVA_HOME}/include" + 
"${ZVEC_JAVA_HOME}/include/${ZVEC_JNI_PLATFORM_INCLUDE_DIR}" +) + +foreach(ZVEC_JNI_INCLUDE_DIR IN LISTS ZVEC_JNI_INCLUDE_DIRS) + if(NOT EXISTS "${ZVEC_JNI_INCLUDE_DIR}") + message(FATAL_ERROR "JNI include directory not found: ${ZVEC_JNI_INCLUDE_DIR}") + endif() +endforeach() + +add_library(zvec_java_jni SHARED + "${ZVEC_JAVA_JNI_SOURCE_DIR}/zvec_java_jni.cc" +) + +set_target_properties(zvec_java_jni PROPERTIES + OUTPUT_NAME "zvec_java_jni" + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CXX_EXTENSIONS OFF + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN ON +) + +target_include_directories(zvec_java_jni + PRIVATE + ${ZVEC_JNI_INCLUDE_DIRS} + ${PROJECT_SOURCE_DIR}/src/include + ${PROJECT_BINARY_DIR}/src/generated +) + +target_link_libraries(zvec_java_jni + PRIVATE + zvec_c_api +) + +if(APPLE) + set_target_properties(zvec_java_jni PROPERTIES + MACOSX_RPATH ON + BUILD_RPATH "@loader_path" + INSTALL_RPATH "@loader_path" + ) +elseif(UNIX) + set_target_properties(zvec_java_jni PROPERTIES + BUILD_RPATH "$ORIGIN" + INSTALL_RPATH "$ORIGIN" + ) +endif() diff --git a/scripts/build_java_native.sh b/scripts/build_java_native.sh new file mode 100755 index 000000000..5eb951b80 --- /dev/null +++ b/scripts/build_java_native.sh @@ -0,0 +1,222 @@ +#!/bin/bash +set -euo pipefail + +ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd) + +usage() { + echo "usage: $0 <native-root-dir> [ffm|jni] [host|platform] [jni-source-dir]" >&2 + echo "platform: darwin-aarch64, darwin-x86_64, linux-aarch64, linux-x86_64, windows-x86_64" >&2 +} + +cpu_count() { + if command -v getconf >/dev/null 2>&1; then + local count + count=$(getconf _NPROCESSORS_ONLN 2>/dev/null || true) + if [ -n "$count" ] && [ "$count" -gt 0 ] 2>/dev/null; then + echo "$count" + return + fi + fi + if command -v sysctl >/dev/null 2>&1; then + local count + count=$(sysctl -n hw.ncpu 2>/dev/null || true) + if [ -n "$count" ] && [ "$count" -gt 0 ] 2>/dev/null; then + echo "$count" + return + fi + fi + if [ -n "${NUMBER_OF_PROCESSORS:-}" ]; then + echo "$NUMBER_OF_PROCESSORS" + return + fi + echo 2 +} + +normalize_arch() { + case "$(echo "$1" | tr '[:upper:]' '[:lower:]')" in + x86_64|amd64) echo "x86_64" ;; + aarch64|arm64) echo "aarch64" ;; + *) + echo "error: unsupported architecture: $1" >&2 + exit 1 + ;; + esac +} + +detect_host_platform() { + local uname_s uname_m os arch + uname_s=$(uname -s 2>/dev/null || echo "${OS:-unknown}") + uname_m=$(uname -m 2>/dev/null || echo "${PROCESSOR_ARCHITECTURE:-unknown}") + arch=$(normalize_arch "$uname_m") + + case "$uname_s" in + Darwin*) os="darwin" ;; + Linux*) os="linux" ;; + MINGW*|MSYS*|CYGWIN*|Windows_NT*) os="windows" ;; + *) + echo "error: unsupported operating system: $uname_s" >&2 + exit 1 + ;; + esac + + normalize_platform "$os-$arch" +} + +normalize_platform() { + case "$(echo "$1" | tr '[:upper:]' '[:lower:]')" in + host) detect_host_platform ;; + darwin-aarch64|macos-aarch64|macos-arm64|osx-aarch64|osx-arm64) echo "darwin-aarch64" ;; + darwin-x86_64|darwin-amd64|macos-x86_64|macos-amd64|osx-x86_64|osx-amd64) echo "darwin-x86_64" ;; + linux-aarch64|linux-arm64) echo "linux-aarch64" ;; + linux-x86_64|linux-amd64) echo "linux-x86_64" ;; + windows-x86_64|windows-amd64|win32-x86_64|win32-amd64) echo "windows-x86_64" ;; + *) + echo "error: unsupported Java native platform: $1" >&2 + usage +
exit 1 + ;; + esac +} + +library_name() { + local base=$1 + local platform=$2 + case "$platform" in + darwin-*) echo "lib${base}.dylib" ;; + linux-*) echo "lib${base}.so" ;; + windows-*) echo "${base}.dll" ;; + *) + echo "error: unsupported Java native platform: $platform" >&2 + exit 1 + ;; + esac +} + +find_built_library() { + local build_dir=$1 + local library_name=$2 + local path + path=$(find "$build_dir" -name "$library_name" -print -quit) + if [ -z "$path" ]; then + echo "error: $library_name not found under $build_dir" >&2 + exit 1 + fi + echo "$path" +} + +fix_macos_install_names() { + local core_source=$1 + local core_output=$2 + local jni_output=$3 + local core_name=$4 + local jni_name=$5 + + if ! command -v install_name_tool >/dev/null 2>&1; then + return + fi + + install_name_tool -id "@rpath/$core_name" "$core_output" || true + install_name_tool -id "@rpath/$jni_name" "$jni_output" || true + install_name_tool -change "$core_source" "@loader_path/$core_name" "$jni_output" || true + install_name_tool -change "$core_output" "@loader_path/$core_name" "$jni_output" || true + install_name_tool -change "@rpath/$core_name" "@loader_path/$core_name" "$jni_output" || true +} + +NATIVE_ROOT_DIR=${1:-} +MODE=${2:-ffm} +PLATFORM_ARG=${3:-host} +JNI_SRC_DIR=${4:-} + +if [ -z "$NATIVE_ROOT_DIR" ]; then + usage + exit 1 +fi + +case "$MODE" in + ffm|jni) ;; + *) + echo "error: unsupported mode: $MODE" >&2 + usage + exit 1 + ;; +esac + +# Backward compatibility with the old jni signature: +# build_java_native.sh <native-root-dir> jni <jni-source-dir> +if [ "$MODE" = "jni" ] && [ -z "$JNI_SRC_DIR" ] && [ -d "$PLATFORM_ARG" ]; then + JNI_SRC_DIR=$PLATFORM_ARG + PLATFORM_ARG=host +fi + +PLATFORM=$(normalize_platform "$PLATFORM_ARG") +HOST_PLATFORM=$(detect_host_platform) +if [ "$PLATFORM" != "$HOST_PLATFORM" ] && [ "${ZVEC_ALLOW_CROSS:-0}" != "1" ]; then + echo "error: requested Java native platform $PLATFORM, but host platform is $HOST_PLATFORM" >&2 + echo "build on a matching host, or set
ZVEC_ALLOW_CROSS=1 with an appropriate CMake toolchain" >&2 + exit 1 +fi +CORE_COUNT=$(cpu_count) +OUTPUT_DIR="$NATIVE_ROOT_DIR/$PLATFORM" +BUILD_PLATFORM=${PLATFORM//-/_} +BUILD_DIR=${ZVEC_NATIVE_BUILD_DIR:-"$ROOT_DIR/build_java_native_$BUILD_PLATFORM"} +CMAKE_LOG="$BUILD_DIR/native-build.log" +CORE_LIBRARY_NAME=$(library_name zvec_c_api "$PLATFORM") +JNI_LIBRARY_NAME=$(library_name zvec_java_jni "$PLATFORM") +BUILD_JAVA_JNI_BINDING=OFF + +if [ "$MODE" = "jni" ]; then + BUILD_JAVA_JNI_BINDING=ON + if [ -z "$JNI_SRC_DIR" ]; then + echo "error: JNI source directory is required in jni mode" >&2 + exit 1 + fi + if [ -z "${JAVA_HOME:-}" ]; then + echo "error: JAVA_HOME must be set to build JNI library" >&2 + exit 1 + fi +fi + +mkdir -p "$OUTPUT_DIR" +mkdir -p "$BUILD_DIR" + +if [ -e "$ROOT_DIR/.git" ]; then + git -C "$ROOT_DIR" submodule update --init --recursive --jobs "$CORE_COUNT" +fi + +if ! cmake -S "$ROOT_DIR" -B "$BUILD_DIR" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 \ + -DBUILD_C_BINDINGS=ON \ + -DBUILD_JAVA_JNI_BINDING="$BUILD_JAVA_JNI_BINDING" \ + -DZVEC_JAVA_HOME="${JAVA_HOME:-}" \ + -DZVEC_JAVA_JNI_SOURCE_DIR="$JNI_SRC_DIR" \ + -DBUILD_PYTHON_BINDINGS=OFF \ + -DBUILD_TESTS=OFF \ + -DBUILD_TOOLS=OFF >"$CMAKE_LOG" 2>&1; then + echo "error: failed to configure zvec native libraries for $PLATFORM; see $CMAKE_LOG" >&2 + exit 1 +fi + +if ! cmake --build "$BUILD_DIR" --target zvec_c_api --config Release -j"$CORE_COUNT" >>"$CMAKE_LOG" 2>&1; then + echo "error: failed to build $CORE_LIBRARY_NAME; see $CMAKE_LOG" >&2 + exit 1 +fi + +CORE_PATH=$(find_built_library "$BUILD_DIR" "$CORE_LIBRARY_NAME") +CORE_OUTPUT="$OUTPUT_DIR/$CORE_LIBRARY_NAME" +cp "$CORE_PATH" "$CORE_OUTPUT" + +if [ "$MODE" = "jni" ]; then + if ! 
cmake --build "$BUILD_DIR" --target zvec_java_jni --config Release -j"$CORE_COUNT" >>"$CMAKE_LOG" 2>&1; then + echo "error: failed to build $JNI_LIBRARY_NAME; see $CMAKE_LOG" >&2 + exit 1 + fi + + JNI_PATH=$(find_built_library "$BUILD_DIR" "$JNI_LIBRARY_NAME") + JNI_OUTPUT="$OUTPUT_DIR/$JNI_LIBRARY_NAME" + cp "$JNI_PATH" "$JNI_OUTPUT" + + if [[ "$PLATFORM" == darwin-* ]]; then + fix_macos_install_names "$CORE_PATH" "$CORE_OUTPUT" "$JNI_OUTPUT" "$CORE_LIBRARY_NAME" "$JNI_LIBRARY_NAME" + fi +fi diff --git a/thirdparty/lz4/CMakeLists.txt b/thirdparty/lz4/CMakeLists.txt index 5e73245f2..e06087e05 100644 --- a/thirdparty/lz4/CMakeLists.txt +++ b/thirdparty/lz4/CMakeLists.txt @@ -94,6 +94,7 @@ if(MSVC) URL "${CMAKE_CURRENT_SOURCE_DIR}/lz4-1.9.4" SOURCE_SUBDIR build/cmake CMAKE_ARGS + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DCMAKE_INSTALL_PREFIX=${EXTERNAL_BINARY_DIR}/usr/local -DCMAKE_INSTALL_LIBDIR=lib -DCMAKE_MSVC_RUNTIME_LIBRARY=${CMAKE_MSVC_RUNTIME_LIBRARY} diff --git a/thirdparty/sparsehash/CMakeLists.txt b/thirdparty/sparsehash/CMakeLists.txt index c389505fc..40920062d 100644 --- a/thirdparty/sparsehash/CMakeLists.txt +++ b/thirdparty/sparsehash/CMakeLists.txt @@ -7,7 +7,10 @@ endif() if(MSVC) set(SPARSEHASH_WINDOWS_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/sparsehash.windows.patch) - apply_patch_once("sparsehash.windows.patch" "${DESTINATION_DIR}" "${SPARSEHASH_WINDOWS_PATCH}") + apply_patch_once( + "sparsehash.windows.patch" + "${CMAKE_CURRENT_SOURCE_DIR}/sparsehash-2.0.4" + "${SPARSEHASH_WINDOWS_PATCH}") endif() add_library(sparsehash INTERFACE) From 2d552943a024654966035e4b7ed5880913c3b47c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=A3=9E?= Date: Thu, 14 May 2026 11:38:45 -0400 Subject: [PATCH 4/5] docs(java): add quickstart and release notes --- README.md | 13 ++ java/zvec-java/README.md | 213 ++++++++++++++++++ .../examples/quickstart-jni/README.md | 29 +++ .../zvec-java/examples/quickstart-jni/pom.xml | 45 ++++ .../java/org/zvec/demo/QuickStartDemo.java | 117 
++++++++++ 5 files changed, 417 insertions(+) create mode 100644 java/zvec-java/README.md create mode 100644 java/zvec-java/examples/quickstart-jni/README.md create mode 100644 java/zvec-java/examples/quickstart-jni/pom.xml create mode 100644 java/zvec-java/examples/quickstart-jni/src/main/java/org/zvec/demo/QuickStartDemo.java diff --git a/README.md b/README.md index 09e5f90e5..5fb728fb1 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,19 @@ npm install @zvec/zvec If you prefer to build Zvec from source, please check the [Building from Source](https://zvec.org/en/docs/db/build/) guide. +### Java (Preview) + +Java preview bindings live in [`java/zvec-java`](./java/zvec-java). + +Current validation target: + +- Java 25 +- macOS ARM64 + +The recommended Java API is the fluent layer (`ZvecSchemas`, `ZvecSearch`), while advanced users can still use the compatibility layer for direct `CollectionSchema`, `VectorSchema`, and `VectorQuery` control. + +See [`java/zvec-java/README.md`](./java/zvec-java/README.md) for build and usage instructions. + ## ⚡ One-Minute Example ```python diff --git a/java/zvec-java/README.md b/java/zvec-java/README.md new file mode 100644 index 000000000..09d4162c0 --- /dev/null +++ b/java/zvec-java/README.md @@ -0,0 +1,213 @@ +# zvec-java + +Java bindings for zvec on the same desktop platforms supported by the native zvec build. + +The recommended Java API is the fluent layer: `ZvecSchemas` for schema construction and `ZvecSearch` for query construction. + +## Artifacts + +- `org.zvec:zvec-java-jni`: JDK 11+ backend using JNI. This is the default choice for new Java users. +- `org.zvec:zvec-java-ffm`: JDK 25 backend using the Foreign Function & Memory API. +- `org.zvec:zvec-java-api`: public API and shared Java implementation. Backend artifacts bring this transitively. + +The old `org.zvec:zvec-java` compatibility coordinate has been removed. Existing FFM users should depend on `org.zvec:zvec-java-ffm` directly. 
+ +Put exactly one backend artifact on the runtime classpath. If both JNI and FFM are present, startup fails unless `-Dorg.zvec.backend=jni` or `-Dorg.zvec.backend=ffm` is set. + +## Requirements + +- Java 25 for a full `java/zvec-java` reactor build +- Java 11+ for `zvec-java-jni` +- Java 25 for `zvec-java-ffm` +- Maven 3.8+ +- CMake available on `PATH` + +Native artifacts are packaged under `META-INF/native/`. Supported platform ids are +`darwin-aarch64`, `darwin-x86_64`, `linux-aarch64`, `linux-x86_64`, and `windows-x86_64`. +The Maven build uses `host` detection by default. Build each native package on a matching host +runner, or pass `-Dzvec.native.platform=<platform>` in that runner to make the package id explicit. + +## Build + +```bash +source "$HOME/.sdkman/bin/sdkman-init.sh" +cd java/zvec-java + +JAVA_HOME="$HOME/.sdkman/candidates/java/25.0.2-oracle" \ + mvn test + +# JDK 11 JNI path +JAVA_HOME="$HOME/.sdkman/candidates/java/11.0.26-amzn" \ + mvn -pl zvec-java-jni -am test + +# JDK 25 FFM path +JAVA_HOME="$HOME/.sdkman/candidates/java/25.0.2-oracle" \ + mvn -pl zvec-java-ffm -am test + +# Explicit platform package on a Linux x86_64 CI runner +JAVA_HOME="$HOME/.sdkman/candidates/java/11.0.26-amzn" \ + mvn -pl zvec-java-jni -am test -Dzvec.native.platform=linux-x86_64 +``` + +## Multi-Platform Release Artifacts + +Do not commit built jars or native libraries to git. Keep `target/` local while preparing +release assets, then upload the assembled jars to GitHub Releases or publish them to a +Maven repository. + +The release flow is: + +1. Build each native platform on a matching runner or machine. +2. Copy the resulting `META-INF/native/` files into the backend module's + `target/classes` directory. +3. Run `mvn package -DskipTests` without `clean` to preserve the copied platform files. +4. Upload `zvec-java-jni/target/zvec-java-jni-*.jar` and + `zvec-java-ffm/target/zvec-java-ffm-*.jar` as release assets.
+ +The JNI jar should contain `zvec_c_api` and `zvec_java_jni` for each packaged platform. +The FFM jar only needs `zvec_c_api` for each packaged platform. + +## Example + +An executable JNI quickstart is available in `examples/quickstart-jni`. It consumes +`org.zvec:zvec-java-jni` as a normal Maven dependency. + +```bash +cd java/zvec-java +mvn -pl zvec-java-jni -am install -DskipTests + +cd examples/quickstart-jni +mvn compile exec:java +``` + +## Quick Start + +```java +import java.util.List; +import org.zvec.Doc; +import org.zvec.Collection; +import org.zvec.CollectionSchema; +import org.zvec.Zvec; +import org.zvec.ZvecSchemas; +import org.zvec.ZvecSearch; + +CollectionSchema schema = + ZvecSchemas.collection("docs").string("title").vector("embedding", 4).balanced().build(); + +try (Collection collection = Zvec.createAndOpen("./docs", schema)) { + collection.insert( + List.of( + Doc.of("doc_1").field("title", "alpha").vector("embedding", new float[] {1f, 0f, 0f, 0f}), + Doc.of("doc_2").field("title", "beta").vector("embedding", new float[] {0f, 1f, 0f, 0f}))); + + List<Doc> results = + collection.query( + ZvecSearch.vector("embedding", new float[] {1f, 0f, 0f, 0f}) + .topK(2) + .project("title") + .build()); +} +``` + +## Common Tuning + +Use these fluent tuning methods immediately after `vector(name, dimension)`: + +- `fast()` when you want the fastest index build and can trade off some recall +- `balanced()` for the default middle ground +- `accurate()` when search quality matters more than build speed +- `expectedDocCount(...)` when you know the collection size ahead of time + +Example: + +```java +CollectionSchema schema = + ZvecSchemas.collection("docs") + .string("title") + .vector("embedding", 1536) + .expectedDocCount(1_000_000L) + .balanced() + .build(); +``` + +## Advanced Control + +If you need direct HNSW configuration, use the compatibility layer: + +```java +import org.zvec.HnswIndexParams; +import org.zvec.HnswQueryParams; +import
org.zvec.VectorQuery; +import org.zvec.VectorSchema; + +VectorSchema schema = + new VectorSchema("embedding", org.zvec.DataType.VECTOR_FP32, 1536) + .withHnswIndex(new HnswIndexParams(32, 300)); + +VectorQuery query = + VectorQuery.of("embedding", new float[] {1f, 0f, 0f, 0f}) + .hnsw(new HnswQueryParams(128, 0.0f, false, true)); +``` + +## Encrypted Fields + +Mark a string field as encrypted in the schema; insert and query call sites stay identical to plaintext code. + +```java +import org.zvec.crypto.KeyProvider; + +KeyProvider keys = keyId -> myKms.fetchKey(keyId); // 32 bytes for AES-256 +CollectionSchema schema = ZvecSchemas.collection("docs") + .string("title") + .string("body").encrypted("body-key-v1") + .vector("embed", 768).balanced() + .build(); + +try (Collection col = Zvec.createAndOpen("./docs", schema, keys)) { + col.insert(List.of( + Doc.of("d1").field("title", "alpha") + .field("body", "plaintext stays plaintext at the call site") + .vector("embed", v))); + + List<Doc> results = col.query( + ZvecSearch.vector("embed", q).topK(10).project("title", "body").build()); + + // results.get(0).fields().get("body") is already plaintext +} +``` + +Reopen with the same provider: + +```java +try (Collection col = Zvec.openWithKeys("./docs", keys)) { ... } +``` + +Key rotation (new writes use the new keyId; existing records keep their original): + +```java +col.setActiveKeyId("body", "body-key-v2"); +``` + +**Key things to know:** + +- AES-256-GCM with a 12-byte random nonce per field per record. Nonce reuse under the same key would be catastrophic; use a `SecureRandom`-backed flow or never reuse keys across processes that don't coordinate nonces. +- `id`, field name, and collection name are bound into AAD automatically. Moving ciphertext between docs/fields/collections is detected. +- Queries cannot filter on encrypted fields. `ZvecSearch.filter("body = 'x'")` throws `IllegalArgumentException`.
+- Decryption failures (tamper, missing key, AAD mismatch) abort the entire query — fail-loud by design. +- The library never logs key material, plaintext, or ciphertext. Callers can add logging in their own try/catch as needed. +- A static-key form `.encrypted(keyId, byte[])` is available for tests and demos. Key bytes are never persisted; reopening still requires a `KeyProvider`. +- Sidecar metadata lives at `<collection-path>/_zvec_enc.json`. Don't hand-edit unless you know what you're doing. + +## Scope + +Current support: + +- create/open collection +- insert documents +- dense float vector query +- string / bool / int64 / double scalar fields + +Deferred: + +- update, upsert, delete, fetch +- sparse vectors diff --git a/java/zvec-java/examples/quickstart-jni/README.md b/java/zvec-java/examples/quickstart-jni/README.md new file mode 100644 index 000000000..93a2dea3e --- /dev/null +++ b/java/zvec-java/examples/quickstart-jni/README.md @@ -0,0 +1,29 @@ +# zvec-java JNI Quickstart + +This example consumes `org.zvec:zvec-java-jni` as a normal Maven dependency. +It does not build native code itself; it expects the zvec-java jars to be installed +locally or available from a Maven repository. + +## Run + +From the repository root: + +```bash +cd java/zvec-java +mvn -pl zvec-java-jni -am install -DskipTests + +cd examples/quickstart-jni +mvn compile exec:java +``` + +The demo creates a collection, inserts a few documents, runs a vector search, +closes the collection, reopens it from disk, and runs the same search again. + +To use a stable collection path instead of a temporary directory: + +```bash +mvn compile exec:java -Dexec.args="/tmp/zvec-java-demo" +``` + +Use JDK 11 or newer for this JNI example. For the FFM backend, switch the +dependency to `org.zvec:zvec-java-ffm` and run on JDK 25.
diff --git a/java/zvec-java/examples/quickstart-jni/pom.xml b/java/zvec-java/examples/quickstart-jni/pom.xml new file mode 100644 index 000000000..45acb2b65 --- /dev/null +++ b/java/zvec-java/examples/quickstart-jni/pom.xml @@ -0,0 +1,45 @@ +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <groupId>org.zvec.examples</groupId> + <artifactId>zvec-java-quickstart-jni</artifactId> + <version>0.0.1-SNAPSHOT</version> + <name>zvec-java-quickstart-jni</name> + + <properties> + <maven.compiler.release>11</maven.compiler.release> + <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> + <zvec.version>0.0.1-SNAPSHOT</zvec.version> + </properties> + + <dependencies> + <dependency> + <groupId>org.zvec</groupId> + <artifactId>zvec-java-jni</artifactId> + <version>${zvec.version}</version> + </dependency> + </dependencies> + + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-compiler-plugin</artifactId> + <version>3.14.0</version> + <configuration> + <release>${maven.compiler.release}</release> + </configuration> + </plugin> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>exec-maven-plugin</artifactId> + <version>3.5.0</version> + <configuration> + <mainClass>org.zvec.demo.QuickStartDemo</mainClass> + </configuration> + </plugin> + </plugins> + </build> +</project> diff --git a/java/zvec-java/examples/quickstart-jni/src/main/java/org/zvec/demo/QuickStartDemo.java b/java/zvec-java/examples/quickstart-jni/src/main/java/org/zvec/demo/QuickStartDemo.java new file mode 100644 index 000000000..a8cd97832 --- /dev/null +++ b/java/zvec-java/examples/quickstart-jni/src/main/java/org/zvec/demo/QuickStartDemo.java @@ -0,0 +1,117 @@ +package org.zvec.demo; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Comparator; +import java.util.List; +import org.zvec.Collection; +import org.zvec.CollectionSchema; +import org.zvec.Doc; +import org.zvec.Zvec; +import org.zvec.ZvecSchemas; +import org.zvec.ZvecSearch; + +public final class QuickStartDemo { // Creates a small collection, queries it, reopens it from disk and queries again. + private static final String VECTOR_FIELD = "embedding"; + + private QuickStartDemo() {} + + public static void main(String[] args) throws Exception { // args[0] (optional) is the collection path; without it a temp directory is used and removed at the end + Path collectionPath = collectionPath(args); + boolean temporary = args.length == 0; + if (!temporary && Files.exists(collectionPath)) { + throw new IllegalArgumentException( + "collection path already exists: " + collectionPath + + System.lineSeparator() + + "Choose an empty path or delete the existing directory first."); + } + + CollectionSchema schema = + ZvecSchemas.collection("demo_docs") + .string("title") + .string("category") + .vector(VECTOR_FIELD, 4) + .balanced() +
.build(); + + try (Collection collection = Zvec.createAndOpen(collectionPath.toString(), schema)) { + int inserted = + collection.insert( + List.of( + Doc.of("doc_1") + .field("title", "Vector search basics") + .field("category", "guide") + .vector(VECTOR_FIELD, new float[] {1.0f, 0.0f, 0.0f, 0.0f}), + Doc.of("doc_2") + .field("title", "Approximate nearest neighbors") + .field("category", "guide") + .vector(VECTOR_FIELD, new float[] {0.8f, 0.2f, 0.0f, 0.0f}), + Doc.of("doc_3") + .field("title", "Release checklist") + .field("category", "ops") + .vector(VECTOR_FIELD, new float[] {0.0f, 1.0f, 0.0f, 0.0f}), + Doc.of("doc_4") + .field("title", "Encrypted fields") + .field("category", "security") + .vector(VECTOR_FIELD, new float[] {0.0f, 0.0f, 1.0f, 0.0f}))); + collection.flush(); + System.out.println("Inserted " + inserted + " documents into " + collectionPath); + printResults("Initial query", search(collection)); + } + + try (Collection reopened = Zvec.open(collectionPath.toString())) { + printResults("Query after reopen", search(reopened)); + } + + if (temporary) { + deleteRecursively(collectionPath.getParent()); + } + } + + private static Path collectionPath(String[] args) throws Exception { // args[0] made absolute, or "docs" inside a fresh temp directory + if (args.length > 0) { + return Paths.get(args[0]).toAbsolutePath().normalize(); + } + Path runDir = Files.createTempDirectory("zvec-java-demo-"); + return runDir.resolve("docs"); + } + + private static List<Doc> search(Collection collection) { // top-3 query with a fixed vector, projecting title and category + return collection.query( + ZvecSearch.vector(VECTOR_FIELD, new float[] {1.0f, 0.0f, 0.0f, 0.0f}) + .topK(3) + .project("title", "category") + .build()); + } + + private static void printResults(String label, List<Doc> docs) { // one line per hit: id, score, title, category + System.out.println(); + System.out.println(label + ":"); + for (Doc doc : docs) { + System.out.printf( + " %s score=%.6f title=\"%s\" category=%s%n", + doc.id(), + doc.score(), + doc.fields().get("title"), + doc.fields().get("category")); + } + } + + private static void deleteRecursively(Path path) throws Exception { +
if (path == null || !Files.exists(path)) { + return; + } + try (java.util.stream.Stream<Path> stream = Files.walk(path)) { // reverse order visits entries before the directory that contains them + stream.sorted(Comparator.reverseOrder()).forEach(QuickStartDemo::deleteOne); + } + } + + private static void deleteOne(Path path) { // rethrows unchecked so it can be used as a forEach method reference + try { + Files.deleteIfExists(path); + } catch (Exception e) { + throw new RuntimeException("failed to delete " + path, e); + } + } +} From fef96cbb2bd634afe1d11faffe19c860510ad1b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=A3=9E?= Date: Sat, 16 May 2026 10:55:04 -0400 Subject: [PATCH 5/5] build(java): add linux-aarch64 glibc 2.28 build image Mirrors Dockerfile.linux_x64_glibc228 for arm64: ubuntu:18.10 (cosmic) on ports.ubuntu.com / old-releases, gcc-9 from focal, glibc 2.28, CMake 3.30.0, Miniforge with py310/311/312 envs. Used to produce linux-aarch64 native libraries for the zvec-java-jni and zvec-java-ffm artifacts. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../docker/Dockerfile.linux_aarch64_glibc228 | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 .github/workflows/docker/Dockerfile.linux_aarch64_glibc228 diff --git a/.github/workflows/docker/Dockerfile.linux_aarch64_glibc228 b/.github/workflows/docker/Dockerfile.linux_aarch64_glibc228 new file mode 100644 index 000000000..42704d13a --- /dev/null +++ b/.github/workflows/docker/Dockerfile.linux_aarch64_glibc228 @@ -0,0 +1,83 @@ +# ============================================================================= +# Dockerfile.linux_aarch64_glibc228 +# Purpose: Ubuntu 18.10 gcc-9 + glibc 2.28 + CMake 3.30.0 + PyBind11 build environment (arm64) +# Warning: ubuntu:18.10 is EOL; use only for glibc 2.28 compatibility testing.
+# ============================================================================= + +# Use official Ubuntu 18.10 (Cosmic Cuttlefish), aarch64 variant +# glibc version: 2.28 (confirmed via `ldd --version`) +FROM --platform=linux/arm64 ubuntu:18.10 + +# Point apt at old-releases.ubuntu.com/ubuntu-ports instead of ports.ubuntu.com (ubuntu:18.10 is EOL). NOTE(review): confirm old-releases really serves an /ubuntu-ports path for arm64. +# Note: arm64 archives live on ports.ubuntu.com, not archive.ubuntu.com +RUN sed -i 's|http://ports.ubuntu.com/ubuntu-ports|http://old-releases.ubuntu.com/ubuntu-ports|g' /etc/apt/sources.list + +# Add the Ubuntu 20.04 (focal) repos, intended for GCC 9 only. NOTE(review): no apt pinning is configured here, so other packages may also resolve from focal; confirm glibc stays at 2.28. +RUN echo "deb http://ports.ubuntu.com/ubuntu-ports/ focal main universe" >> /etc/apt/sources.list && \ + echo "deb http://ports.ubuntu.com/ubuntu-ports/ focal-security main universe" >> /etc/apt/sources.list + +# Prevent interactive apt prompts and set the timezone to UTC +ENV DEBIAN_FRONTEND=noninteractive \ + TZ=Etc/UTC + +# Create non-root user for safety (optional but recommended) +RUN useradd -m -u 1000 builder && \ + mkdir -p /workspace && chown builder:builder /workspace + +# Install base system dependencies and register gcc-9/g++-9 as the default gcc/g++ +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + build-essential \ + gcc-9 g++-9 \ + ninja-build git curl ca-certificates vim wget lcov gnupg clang-format-18\ + rsync lsb-release \ + uuid-dev zlib1g-dev libssl-dev libffi-dev \ + pybind11-dev && \ + update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 90 \ + --slave /usr/bin/g++ g++ /usr/bin/g++-9 && \ + rm -rf /var/lib/apt/lists/* + +# Install Miniforge (Conda) as root, then chown the install to builder +ENV MINIFORGE_VERSION="latest" +ENV MINIFORGE_HOME="/opt/miniforge3" + +RUN curl -sSL "https://github.com/conda-forge/miniforge/releases/${MINIFORGE_VERSION}/download/Miniforge3-Linux-aarch64.sh" -o miniforge.sh && \ + bash miniforge.sh -b -p ${MINIFORGE_HOME} && \ + rm miniforge.sh && \ + chown -R builder:builder ${MINIFORGE_HOME} + +# Switch to non-root user +USER builder +ENV PATH="${MINIFORGE_HOME}/bin:${PATH}"
+WORKDIR /workspace + +# Create conda envs for supported Python versions +RUN conda create -n py310 python=3.10 -y && \ + conda create -n py311 python=3.11 -y && \ + conda create -n py312 python=3.12 -y +RUN conda clean --all -f -y + +# Install CMake 3.30.0 from the official Kitware binary into /home/builder/.local (this step runs as builder) +# Ref: https://github.com/Kitware/CMake/releases/tag/v3.30.0 +RUN mkdir -p /tmp/cmake && cd /tmp/cmake && \ + curl -sSL -o cmake.tar.gz \ + "https://github.com/Kitware/CMake/releases/download/v3.30.0/cmake-3.30.0-linux-aarch64.tar.gz" && \ + tar -xzf cmake.tar.gz --strip-components=1 -C /tmp/cmake && \ + mkdir -p /home/builder/.local && \ + mv * /home/builder/.local/ && \ + chown -R builder:builder /home/builder/.local && \ + rm -rf /tmp/cmake + +# Add CMake to PATH +ENV PATH="/home/builder/.local/bin:${PATH}" + +# Verify installations: print tool versions; the last command prints the first line of `ldd --version` (the glibc version) +RUN cmake --version && \ + conda info && \ + conda env list && \ + python --version && \ + gcc --version && \ + ldd --version | head -n1 + +# Final setup: default working directory +WORKDIR /workspace