Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bindings/cpp/include/svs/runtime/vamana_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ struct SVS_RUNTIME_API VamanaIndex {
IDFilter* filter = nullptr
) const noexcept = 0;

// Compute distance between stored vector `id` and `query` (dim floats).
virtual Status
get_distance(double* distance, size_t id, const float* query) const noexcept = 0;

// Reconstruct `n` vectors by ID into `output` buffer (n * dim floats).
virtual Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept = 0;

// Utility function to check storage kind support
static Status check_storage_kind(StorageKind storage_kind) noexcept;

Expand Down
9 changes: 9 additions & 0 deletions bindings/cpp/src/dynamic_vamana_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ struct DynamicVamanaIndexManagerBase : public DynamicVamanaIndex {
Status save(std::ostream& out) const noexcept override {
return runtime_error_wrapper([&] { impl_->save(out); });
}

Status
get_distance(double* distance, size_t id, const float* query) const noexcept override {
return runtime_error_wrapper([&] { *distance = impl_->get_distance(id, query); });
}

Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override {
return runtime_error_wrapper([&] { impl_->reconstruct_at(n, ids, output); });
}
};
} // namespace

Expand Down
17 changes: 17 additions & 0 deletions bindings/cpp/src/dynamic_vamana_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,23 @@ class DynamicVamanaIndexImpl {
return remove(ids_to_delete);
}

double get_distance(size_t id, const float* query) const {
if (!impl_) {
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}
auto query_span = std::span<const float>(query, dim_);
return impl_->get_distance(id, query_span);
}

void reconstruct_at(size_t n, const size_t* ids, float* output) {
if (!impl_) {
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}
svs::data::SimpleDataView<float> dst{output, n, dim_};
std::span<const uint64_t> id_span{reinterpret_cast<const uint64_t*>(ids), n};
impl_->reconstruct_at(dst, id_span);
}

void reset() {
impl_.reset();
ntotal_soft_deleted = 0;
Expand Down
9 changes: 9 additions & 0 deletions bindings/cpp/src/vamana_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ struct VamanaIndexManagerBase : public VamanaIndex {
Status save(std::ostream& out) const noexcept override {
return runtime_error_wrapper([&] { impl_->save(out); });
}

Status
get_distance(double* distance, size_t id, const float* query) const noexcept override {
return runtime_error_wrapper([&] { *distance = impl_->get_distance(id, query); });
}

Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override {
return runtime_error_wrapper([&] { impl_->reconstruct_at(n, ids, output); });
}
};
} // namespace

Expand Down
17 changes: 17 additions & 0 deletions bindings/cpp/src/vamana_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,23 @@ class VamanaIndexImpl {
}
}

double get_distance(size_t id, const float* query) const {
if (!impl_) {
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}
auto query_span = std::span<const float>(query, dim_);
return get_impl()->get_distance(id, query_span);
}

void reconstruct_at(size_t n, const size_t* ids, float* output) {
if (!impl_) {
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}
svs::data::SimpleDataView<float> dst{output, n, dim_};
std::span<const uint64_t> id_span{reinterpret_cast<const uint64_t*>(ids), n};
get_impl()->reconstruct_at(dst, id_span);
}

void reset() { impl_.reset(); }

void save(std::ostream& out) const {
Expand Down
138 changes: 138 additions & 0 deletions bindings/cpp/tests/runtime_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,141 @@ CATCH_TEST_CASE("RangeSearchFunctionalStatic", "[runtime][static_vamana]") {

svs::runtime::v0::VamanaIndex::destroy(index);
}

CATCH_TEST_CASE("GetDistanceDynamic", "[runtime]") {
const auto& test_data = get_test_data();
svs::runtime::v0::DynamicVamanaIndex* index = nullptr;
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
auto status = svs::runtime::v0::DynamicVamanaIndex::build(
&index,
test_d,
svs::runtime::v0::MetricType::L2,
svs::runtime::v0::StorageKind::FP32,
build_params
);
CATCH_REQUIRE(status.ok());

std::vector<size_t> labels(test_n);
std::iota(labels.begin(), labels.end(), 0);
status = index->add(test_n, labels.data(), test_data.data());
CATCH_REQUIRE(status.ok());

// Self-distance should be approximately 0
double dist = -1.0;
const float* vec0 = test_data.data();
status = index->get_distance(&dist, 0, vec0);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(dist < 1e-6);

// Distance to a different vector should be positive
const float* vec1 = test_data.data() + test_d;
status = index->get_distance(&dist, 0, vec1);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(dist > 0.0);

svs::runtime::v0::DynamicVamanaIndex::destroy(index);
}

CATCH_TEST_CASE("GetDistanceStatic", "[runtime][static_vamana]") {
const auto& test_data = get_test_data();
svs::runtime::v0::VamanaIndex* index = nullptr;
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
auto status = svs::runtime::v0::VamanaIndex::build(
&index,
test_d,
svs::runtime::v0::MetricType::L2,
svs::runtime::v0::StorageKind::FP32,
build_params
);
CATCH_REQUIRE(status.ok());

status = index->add(test_n, test_data.data());
CATCH_REQUIRE(status.ok());

// Self-distance should be approximately 0
double dist = -1.0;
const float* vec0 = test_data.data();
status = index->get_distance(&dist, 0, vec0);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(dist < 1e-6);

// Distance to a different vector should be positive
const float* vec1 = test_data.data() + test_d;
status = index->get_distance(&dist, 0, vec1);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(dist > 0.0);

svs::runtime::v0::VamanaIndex::destroy(index);
}

CATCH_TEST_CASE("ReconstructAtDynamic", "[runtime]") {
const auto& test_data = get_test_data();
svs::runtime::v0::DynamicVamanaIndex* index = nullptr;
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
auto status = svs::runtime::v0::DynamicVamanaIndex::build(
&index,
test_d,
svs::runtime::v0::MetricType::L2,
svs::runtime::v0::StorageKind::FP32,
build_params
);
CATCH_REQUIRE(status.ok());

std::vector<size_t> labels(test_n);
std::iota(labels.begin(), labels.end(), 0);
status = index->add(test_n, labels.data(), test_data.data());
CATCH_REQUIRE(status.ok());

// Reconstruct first 5 vectors
constexpr size_t nrecon = 5;
std::vector<size_t> ids(nrecon);
std::iota(ids.begin(), ids.end(), 0);
std::vector<float> output(nrecon * test_d, 0.0f);

status = index->reconstruct_at(nrecon, ids.data(), output.data());
CATCH_REQUIRE(status.ok());

// For FP32 storage, reconstructed vectors should match originals exactly
for (size_t i = 0; i < nrecon; ++i) {
for (size_t j = 0; j < test_d; ++j) {
CATCH_REQUIRE(output[i * test_d + j] == test_data[i * test_d + j]);
}
}

svs::runtime::v0::DynamicVamanaIndex::destroy(index);
}

CATCH_TEST_CASE("ReconstructAtStatic", "[runtime][static_vamana]") {
const auto& test_data = get_test_data();
svs::runtime::v0::VamanaIndex* index = nullptr;
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
auto status = svs::runtime::v0::VamanaIndex::build(
&index,
test_d,
svs::runtime::v0::MetricType::L2,
svs::runtime::v0::StorageKind::FP32,
build_params
);
CATCH_REQUIRE(status.ok());

status = index->add(test_n, test_data.data());
CATCH_REQUIRE(status.ok());

// Reconstruct first 5 vectors
constexpr size_t nrecon = 5;
std::vector<size_t> ids(nrecon);
std::iota(ids.begin(), ids.end(), 0);
std::vector<float> output(nrecon * test_d, 0.0f);

status = index->reconstruct_at(nrecon, ids.data(), output.data());
CATCH_REQUIRE(status.ok());

// For FP32 storage, reconstructed vectors should match originals exactly
for (size_t i = 0; i < nrecon; ++i) {
for (size_t j = 0; j < test_d; ++j) {
CATCH_REQUIRE(output[i * test_d + j] == test_data[i * test_d + j]);
}
}

svs::runtime::v0::VamanaIndex::destroy(index);
}
Loading