Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions tok/hnsw/ef_recall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"encoding/binary"
"math"
"sync"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -209,3 +210,94 @@ func TestHNSWDistanceThreshold_Cosine(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []uint64{1}, res)
}

// deadNodesTxn is a minimal index.Txn for exercising removeDeadNodes: Get reads
// from an in-memory map and StartTs is fixed. The rest are unused no-ops.
type deadNodesTxn struct {
startTs uint64
data map[string][]byte
}

func (t *deadNodesTxn) StartTs() uint64 { return t.startTs }
func (t *deadNodesTxn) Get(key []byte) ([]byte, error) { return t.data[string(key)], nil }
func (t *deadNodesTxn) GetWithLockHeld(key []byte) ([]byte, error) { return t.data[string(key)], nil }
func (t *deadNodesTxn) Find([]byte, func([]byte) bool) (uint64, error) { return 0, nil }
func (t *deadNodesTxn) AddMutation(context.Context, []byte, *index.KeyValue) error {
return nil
}
func (t *deadNodesTxn) AddMutationWithLockHeld(context.Context, []byte, *index.KeyValue) error {
return nil
}
func (t *deadNodesTxn) LockKey([]byte) {}
func (t *deadNodesTxn) UnlockKey([]byte) {}

// TestRemoveDeadNodesRefreshesAcrossTimestamps guards the fix for the
// load-once-never-refresh bug: the dead-node set must be re-read when the
// transaction timestamp advances, while staying stable within a single snapshot.
func TestRemoveDeadNodesRefreshesAcrossTimestamps(t *testing.T) {
ph := &persistentHNSW[float64]{vecDead: ConcatStrings("0-dead", VecDead)}
deadKey := string(DataKey(ph.vecDead, 1))
store := map[string][]byte{}

// ts=10: nothing is dead yet, so nothing is filtered.
tc1 := NewTxnCache(&deadNodesTxn{startTs: 10, data: store}, 10)
out, err := ph.removeDeadNodes([]uint64{1, 2, 3}, tc1)
require.NoError(t, err)
require.Equal(t, []uint64{1, 2, 3}, out)

// A delete makes uid 2 dead. Reusing the same snapshot (ts=10) must NOT see
// it — reads at a fixed StartTs are snapshot-consistent.
store[deadKey] = []byte("[2]")
out, err = ph.removeDeadNodes([]uint64{1, 2, 3}, tc1)
require.NoError(t, err)
require.Equal(t, []uint64{1, 2, 3}, out)

// A newer transaction (ts=20) MUST observe the deletion. Before the fix the
// cache was loaded once and this still returned {1,2,3}.
tc2 := NewTxnCache(&deadNodesTxn{startTs: 20, data: store}, 20)
out, err = ph.removeDeadNodes([]uint64{1, 2, 3}, tc2)
require.NoError(t, err)
require.Equal(t, []uint64{1, 3}, out)
}

// TestRemoveDeadNodesSnapshotIsolation verifies the shared cache only advances
// in time: once a newer snapshot is cached, an older-ts caller must still see
// its own (older) view, not the newer set of deletions.
func TestRemoveDeadNodesSnapshotIsolation(t *testing.T) {
ph := &persistentHNSW[float64]{vecDead: ConcatStrings("0-dead", VecDead)}
deadKey := string(DataKey(ph.vecDead, 1))

// Newer txn (ts=20) sees uid 2 as dead and installs the cache at ts=20.
tcNew := NewTxnCache(&deadNodesTxn{startTs: 20, data: map[string][]byte{deadKey: []byte("[2]")}}, 20)
out, err := ph.removeDeadNodes([]uint64{1, 2, 3}, tcNew)
require.NoError(t, err)
require.Equal(t, []uint64{1, 3}, out)

// Older txn (ts=10), at whose snapshot uid 2 is NOT yet dead, must not be
// affected by the newer cached snapshot.
tcOld := NewTxnCache(&deadNodesTxn{startTs: 10, data: map[string][]byte{}}, 10)
out, err = ph.removeDeadNodes([]uint64{1, 2, 3}, tcOld)
require.NoError(t, err)
require.Equal(t, []uint64{1, 2, 3}, out)
}

// TestLoadDeadNodesConcurrent exercises the lock-free publication path under the
// race detector: many goroutines at mixed timestamps loading concurrently.
func TestLoadDeadNodesConcurrent(t *testing.T) {
ph := &persistentHNSW[float64]{vecDead: ConcatStrings("0-dead", VecDead)}
deadKey := string(DataKey(ph.vecDead, 1))
data := map[string][]byte{deadKey: []byte("[2]")}

var wg sync.WaitGroup
for g := range 32 {
wg.Add(1)
go func(ts uint64) {
defer wg.Done()
tc := NewTxnCache(&deadNodesTxn{startTs: ts, data: data}, ts)
out, err := ph.removeDeadNodes([]uint64{1, 2, 3}, tc)
require.NoError(t, err)
require.Equal(t, []uint64{1, 3}, out)
}(uint64(10 + g%4)) // timestamps 10..13
}
wg.Wait()
}
95 changes: 72 additions & 23 deletions tok/hnsw/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -743,40 +743,89 @@ func (ph *persistentHNSW[T]) addNeighbors(ctx context.Context, tc *TxnCache,

// removeDeadNodes(nnEdges, tc) removes dead nodes from nnEdges and returns the new nnEdges
func (ph *persistentHNSW[T]) removeDeadNodes(nnEdges []uint64, tc *TxnCache) ([]uint64, error) {
// TODO add a path to delete deadNodes
if ph.deadNodes == nil {
data, err := getDataFromKeyWithCacheType(ph.vecDead, 1, tc)
if err != nil && !errors.Is(err, errFetchingPostingList) {
return []uint64{}, err
}

var deadNodes []uint64
if data != nil { // if dead nodes exist, convert to []uint64
deadNodes, err = ParseEdges(string(data))
if err != nil {
return []uint64{}, err
}
}

ph.deadNodes = make(map[uint64]struct{})
for _, n := range deadNodes {
ph.deadNodes[n] = struct{}{}
}
deadNodes, err := ph.loadDeadNodes(tc)
if err != nil {
return []uint64{}, err
}
if len(ph.deadNodes) == 0 {
if len(deadNodes) == 0 {
return nnEdges, nil
}

var diff []uint64
diff := make([]uint64, 0, len(nnEdges))
for _, s := range nnEdges {
if _, ok := ph.deadNodes[s]; !ok {
if _, ok := deadNodes[s]; !ok {
diff = append(diff, s)
continue
}
}
return diff, nil
}

// loadDeadNodes returns the set of tombstoned (deleted) vector UIDs visible at
// the cache's read timestamp.
//
// The dead set is persisted as a single posting (DataKey(vecDead, 1)) that grows
// as vectors are deleted. It used to be loaded once and never refreshed, so any
// vector deleted after the first call stayed invisible to the neighbour filter
// for the lifetime of the index instance — dead UIDs leaked back into edge
// lists. We instead cache it as an immutable snapshot tagged with its read
// timestamp: a rebuild streams every key at a single StartTs, so the JSON is
// parsed once and reused across the many removeDeadNodes calls per insert, while
// later transactions (a newer StartTs) reload and observe new deletions.
//
// The shared cache only ever advances in time. A transaction never observes
// deletions newer than its own snapshot: if a newer snapshot is already cached,
// the caller gets its own freshly-loaded set without overwriting the cache. The
// returned map is immutable, so callers read it without locking.
//
// Concurrent loaders at the same ts each build a set and race to publish; the
// losers reuse the winner's. That one-time duplicate parse is bounded (one per
// rebuild goroutine) and the posting read is in-memory during a rebuild, so no
// singleflight is warranted.
func (ph *persistentHNSW[T]) loadDeadNodes(tc *TxnCache) (map[uint64]struct{}, error) {
ts := tc.Ts()
if cur := ph.deadNodes.Load(); cur != nil && cur.ts == ts {
return cur.set, nil
}

data, err := getDataFromKeyWithCacheType(ph.vecDead, 1, tc)
if err != nil && !errors.Is(err, errFetchingPostingList) {
return nil, err
}

var deadNodes []uint64
if data != nil { // if dead nodes exist, convert to []uint64
deadNodes, err = ParseEdges(string(data))
if err != nil {
return nil, err
}
}

loaded := make(map[uint64]struct{}, len(deadNodes))
for _, n := range deadNodes {
loaded[n] = struct{}{}
}
snap := &deadSnapshot{ts: ts, set: loaded}

for {
cur := ph.deadNodes.Load()
switch {
case cur != nil && cur.ts == ts:
// A concurrent loader already published a snapshot for this ts.
// Reads at a fixed ts are deterministic, so cur.set equals what we
// loaded; reuse the shared one and drop ours.
return cur.set, nil
case cur != nil && cur.ts > ts:
// A newer snapshot is cached. Serve our own ts-scoped set without
// installing it, so older readers keep snapshot isolation.
return loaded, nil
default:
if ph.deadNodes.CompareAndSwap(cur, snap) {
return loaded, nil
}
}
}
}

func Uint64ToBytes(key uint64) []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, key)
Expand Down
17 changes: 16 additions & 1 deletion tok/hnsw/persistent_hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"fmt"
"strings"
"sync/atomic"
"time"

c "github.com/dgraph-io/dgraph/v25/tok/constraints"
Expand All @@ -31,7 +32,21 @@ type persistentHNSW[T c.Float] struct {
// nodeAllEdges[65443][1][3] indicates the 3rd neighbor in the first
// layer for UUID 65443. The result will be a neighboring UUID.
nodeAllEdges map[uint64][][]uint64
deadNodes map[uint64]struct{}

// deadNodes caches the tombstoned (deleted) vector set — the persisted
// vecDead posting — as an immutable snapshot tagged with the read timestamp
// it was loaded at, so it is refreshed when the snapshot advances (it used to
// be loaded once and never refreshed; see loadDeadNodes). Published
// atomically because the index instance is shared across the goroutines that
// drive an index rebuild.
deadNodes atomic.Pointer[deadSnapshot]
}

// deadSnapshot is an immutable view of the dead-node set as of a read timestamp.
// It is never mutated after construction, so readers use set without locking.
type deadSnapshot struct {
ts uint64
set map[uint64]struct{}
}

func GetPersistantOptions[T c.Float](o opt.Options) string {
Expand Down
Loading