From 95dc3a17dc7f1c2c99bfdecc8938688f7b89d9d5 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 4 Mar 2026 16:26:46 -0500 Subject: [PATCH 01/12] feat(search): add BM25 ranked text search Add BM25 relevance-ranked text search to Dgraph, enabling users to query text predicates and receive results ordered by relevance score instead of boolean matching. Implementation: - New BM25 tokenizer using the fulltext pipeline (normalize, stopwords, stem) that preserves term frequencies for TF counting - BM25-specific index storage: per-term TF posting lists, doc length lists, and corpus statistics (doc count, total terms) - Query execution with full BM25 scoring: score = IDF * (k+1) * tf / (k * (1 - b + b * dl/avgDL) + tf) IDF = log1p((N - df + 0.5) / (df + 0.5)) - DQL syntax: bm25(predicate, "query" [, "k", "b"]) as root func or filter - Schema syntax: @index(bm25) - Parameter validation (k > 0, 0 <= b <= 1) - Early UID intersection for filter-mode performance - All-stopword document and query handling Co-Authored-By: Claude Opus 4.6 --- dql/parser.go | 2 +- posting/index.go | 183 +++++++++++++++++++++++++++++++++ query/common_test.go | 17 +++ query/query_bm25_test.go | 214 ++++++++++++++++++++++++++++++++++++++ tok/tok.go | 43 ++++++++ tok/tok_test.go | 140 +++++++++++++++++++++++++ tok/tokens.go | 25 +++++ worker/task.go | 216 ++++++++++++++++++++++++++++++++++++++- worker/tokens.go | 5 + x/keys.go | 19 ++++ 10 files changed, 861 insertions(+), 3 deletions(-) create mode 100644 query/query_bm25_test.go diff --git a/dql/parser.go b/dql/parser.go index 0dd6e1db7ac..666c3eacaab 100644 --- a/dql/parser.go +++ b/dql/parser.go @@ -1701,7 +1701,7 @@ func validFuncName(name string) bool { switch name { case "regexp", "anyofterms", "allofterms", "alloftext", "anyoftext", "ngram", - "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25": return true } return false diff --git a/posting/index.go b/posting/index.go index ae6c3352a44..88c0e5920a9 100644 --- a/posting/index.go +++ b/posting/index.go @@ -68,6 +68,10 @@ func indexTokens(ctx context.Context, info *indexMutationInfo) ([]string, error) var tokens []string for _, it := range info.tokenizers { + // BM25 tokenizer is handled separately in addBM25IndexMutations. + if it.Identifier() == tok.IdentBM25 { + continue + } toks, err := tok.BuildTokens(sv.Value, tok.GetTokenizerForLang(it, lang)) if err != nil { return tokens, err @@ -179,6 +183,17 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo) } } + // Check if any tokenizer is BM25 and handle separately. + for _, it := range info.tokenizers { + if _, ok := tok.GetTokenizerForLang(it, info.edge.GetLang()).(tok.BM25Tokenizer); ok { + if err := txn.addBM25IndexMutations(ctx, info); err != nil { + return []*pb.DirectedEdge{}, err + } + // Continue to process remaining non-BM25 tokenizers below. + continue + } + } + tokens, err := indexTokens(ctx, info) if err != nil { // This data is not indexable @@ -215,6 +230,174 @@ func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, tok return nil } +// addBM25IndexMutations handles index mutations for the BM25 tokenizer. +// It stores term frequencies, document lengths, and corpus statistics. +func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationInfo) error { + attr := info.edge.Attr + uid := info.edge.Entity + lang := info.edge.GetLang() + + schemaType, err := schema.State().TypeOf(attr) + if err != nil || !schemaType.IsScalar() { + return errors.Errorf("Cannot BM25 index attribute %s of type object.", attr) + } + + sv, err := types.Convert(info.val, schemaType) + if err != nil { + return err + } + + bm25Tok := tok.BM25Tokenizer{} + termFreqs, docLen, err := bm25Tok.TokensWithFrequency(sv.Value, lang) + if err != nil { + return err + } + + // Skip documents that tokenize to zero terms (e.g., all stopwords). + if docLen == 0 { + return nil + } + + if info.op == pb.DirectedEdge_DEL { + // For DELETE: remove uid from all BM25 term posting lists, doc length list, + // and decrement corpus stats. + for term := range termFreqs { + encodedTerm := string([]byte{tok.IdentBM25}) + term + key := x.BM25IndexKey(attr, encodedTerm) + plist, err := txn.cache.GetFromDelta(key) + if err != nil { + return err + } + edge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Op: pb.DirectedEdge_DEL, + } + if err := plist.addMutation(ctx, txn, edge); err != nil { + return err + } + } + // Remove doc length entry. + dlKey := x.BM25DocLenKey(attr) + dlPlist, err := txn.cache.GetFromDelta(dlKey) + if err != nil { + return err + } + dlEdge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Op: pb.DirectedEdge_DEL, + } + if err := dlPlist.addMutation(ctx, txn, dlEdge); err != nil { + return err + } + + // Update corpus stats: decrement doc count and total terms. + return txn.updateBM25Stats(ctx, attr, -1, -int64(docLen)) + } + + // For SET: store term frequencies, doc length, and update corpus stats. + for term, tf := range termFreqs { + encodedTerm := string([]byte{tok.IdentBM25}) + term + key := x.BM25IndexKey(attr, encodedTerm) + plist, err := txn.cache.GetFromDelta(key) + if err != nil { + return err + } + // Store uid in the posting list. The TF is encoded in the Value field. + tfBuf := make([]byte, 4) + binary.BigEndian.PutUint32(tfBuf, tf) + edge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Value: tfBuf, + ValueType: pb.Posting_INT, + Op: pb.DirectedEdge_SET, + } + if err := plist.addMutation(ctx, txn, edge); err != nil { + return err + } + } + + // Store document length. + dlKey := x.BM25DocLenKey(attr) + dlPlist, err := txn.cache.GetFromDelta(dlKey) + if err != nil { + return err + } + dlBuf := make([]byte, 4) + binary.BigEndian.PutUint32(dlBuf, docLen) + dlEdge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Value: dlBuf, + ValueType: pb.Posting_INT, + Op: pb.DirectedEdge_SET, + } + if err := dlPlist.addMutation(ctx, txn, dlEdge); err != nil { + return err + } + + // Update corpus stats: increment doc count by 1 and total terms by docLen. + return txn.updateBM25Stats(ctx, attr, 1, int64(docLen)) +} + +// updateBM25Stats reads the current corpus statistics for a BM25-indexed attribute, +// applies the given deltas, and writes back. +func (txn *Txn) updateBM25Stats(ctx context.Context, attr string, docCountDelta int64, totalTermsDelta int64) error { + statsKey := x.BM25StatsKey(attr) + plist, err := txn.cache.GetFromDelta(statsKey) + if err != nil { + return err + } + + // Read existing stats from posting with uid=1. + var docCount, totalTerms uint64 + val, err := plist.Value(txn.StartTs) + if err == nil && val.Value != nil { + data, ok := val.Value.([]byte) + if ok && len(data) == 16 { + docCount = binary.BigEndian.Uint64(data[0:8]) + totalTerms = binary.BigEndian.Uint64(data[8:16]) + } + } + + // Apply deltas. + if docCountDelta >= 0 { + docCount += uint64(docCountDelta) + } else { + dec := uint64(-docCountDelta) + if dec > docCount { + docCount = 0 + } else { + docCount -= dec + } + } + if totalTermsDelta >= 0 { + totalTerms += uint64(totalTermsDelta) + } else { + dec := uint64(-totalTermsDelta) + if dec > totalTerms { + totalTerms = 0 + } else { + totalTerms -= dec + } + } + + // Write back stats. + statsBuf := make([]byte, 16) + binary.BigEndian.PutUint64(statsBuf[0:8], docCount) + binary.BigEndian.PutUint64(statsBuf[8:16], totalTerms) + edge := &pb.DirectedEdge{ + Entity: 1, + Attr: attr, + Value: statsBuf, + ValueType: pb.Posting_ValType(0), + Op: pb.DirectedEdge_SET, + } + return plist.addMutation(ctx, txn, edge) +} + // countParams is sent to updateCount function. It is used to update the count index. // It deletes the uid from the key corresponding to and adds it // to . diff --git a/query/common_test.go b/query/common_test.go index e36211f7a18..32a3e65a81b 100644 --- a/query/common_test.go +++ b/query/common_test.go @@ -390,6 +390,11 @@ func populateCluster(dc dgraphapi.Cluster) { testSchema += "\ndescription: string @index(ngram) ." } + // BM25 indexing - uses same version gate as ngram for now + if ngramSupport { + testSchema += "\ndescription_bm25: string @index(bm25) ." + } + setSchema(testSchema) err = addTriplesToCluster(` @@ -1007,4 +1012,16 @@ func populateCluster(dc dgraphapi.Cluster) { <415> "Linguistic analysis helps understand text meaning" . `) x.Panic(err) + + // Add data for BM25 tests - uses separate predicate to avoid conflicts + err = addTriplesToCluster(` + <501> "The quick brown fox jumps over the lazy dog" . + <502> "A quick brown fox leaps over a sleeping dog" . + <503> "fox fox fox" . + <504> "The lazy dog sleeps under the warm sun all day long in the garden" . + <505> "Dogs are loyal companions to humans and families everywhere" . + <506> "Quick movements help foxes catch their prey in the wild" . + <507> "Brown foxes are quick and agile animals in the forest" . + `) + x.Panic(err) } diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go new file mode 100644 index 00000000000..f0a3a0c16a9 --- /dev/null +++ b/query/query_bm25_test.go @@ -0,0 +1,214 @@ +//go:build integration || cloud + +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +//nolint:lll +package query + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBM25Basic(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "quick brown fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should return documents containing "quick", "brown", or "fox" + require.Contains(t, js, "quick brown fox jumps") + require.Contains(t, js, "quick brown fox leaps") +} + +func TestBM25Ordering(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Document 503 has "fox fox fox" (tf=3, short doc) so should rank highest. + // Verify it appears before other fox-containing documents in the output. + foxFoxFoxIdx := strings.Index(js, "fox fox fox") + quickBrownIdx := strings.Index(js, "quick brown fox jumps") + require.Greater(t, foxFoxFoxIdx, -1, "should contain 'fox fox fox'") + require.Greater(t, quickBrownIdx, -1, "should contain 'quick brown fox jumps'") + require.Less(t, foxFoxFoxIdx, quickBrownIdx, + "'fox fox fox' (higher tf, shorter doc) should rank before 'quick brown fox jumps'") +} + +func TestBM25WithParams(t *testing.T) { + // Custom k and b parameters + query := ` + { + me(func: bm25(description_bm25, "fox", "1.5", "0.5")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "fox") +} + +func TestBM25InvalidParams(t *testing.T) { + // Negative k should be rejected. + query := ` + { + me(func: bm25(description_bm25, "fox", "-1.0", "0.75")) { + uid + } + } + ` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: k must be a positive finite number") + + // b > 1 should be rejected. + query2 := ` + { + me(func: bm25(description_bm25, "fox", "1.2", "1.5")) { + uid + } + } + ` + _, err = processQuery(context.Background(), t, query2) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: b must be between 0 and 1") + + // b < 0 should be rejected. + query3 := ` + { + me(func: bm25(description_bm25, "fox", "1.2", "-0.5")) { + uid + } + } + ` + _, err = processQuery(context.Background(), t, query3) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: b must be between 0 and 1") +} + +func TestBM25AsFilter(t *testing.T) { + query := ` + { + me(func: has(description_bm25)) @filter(bm25(description_bm25, "fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "fox") + // Should not contain documents without "fox" + require.NotContains(t, js, "Dogs are loyal") +} + +func TestBM25NoResults(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "xyznonexistent")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25SingleTerm(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "dog")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "dog") +} + +func TestBM25MultiTerm(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "quick lazy")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should find docs with "quick" or "lazy" (scores accumulate). + // Doc 501 has both "quick" and "lazy", so it should rank high. + require.Contains(t, js, "quick brown fox jumps over the lazy dog") +} + +func TestBM25AllStopwords(t *testing.T) { + // A query consisting entirely of stopwords should return no results. + query := ` + { + me(func: bm25(description_bm25, "the a an")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25EmptyPredicate(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "")) { + uid + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25WithCount(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "fox")) { + count(uid) + } + } + ` + js := processQueryNoErr(t, query) + // Should have at least 2 results (docs with "fox") + require.Contains(t, js, "count") +} + +func TestBM25Pagination(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // With first:1, should return exactly one result (the highest-scoring). + // Doc 503 "fox fox fox" should be the top result. + require.Contains(t, js, "fox fox fox") +} diff --git a/tok/tok.go b/tok/tok.go index c1da3e991d7..cb50b0a369e 100644 --- a/tok/tok.go +++ b/tok/tok.go @@ -50,6 +50,7 @@ const ( IdentBigFloat = 0xD IdentVFloat = 0xE IdentNGram = 0xF + IdentBM25 = 0x10 IdentCustom = 0x80 IdentDelimiter = 0x1f // ASCII 31 - Unit separator ) @@ -101,6 +102,7 @@ func init() { registerTokenizer(TermTokenizer{}) registerTokenizer(FullTextTokenizer{}) registerTokenizer(NGramTokenizer{}) + registerTokenizer(BM25Tokenizer{}) registerTokenizer(Sha256Tokenizer{}) setupBleve() } @@ -576,6 +578,47 @@ func (t FullTextTokenizer) Identifier() byte { return IdentFullText } func (t FullTextTokenizer) IsSortable() bool { return false } func (t FullTextTokenizer) IsLossy() bool { return true } +// BM25Tokenizer generates tokens for BM25 ranked text search. +// It uses the same pipeline as FullTextTokenizer (normalize, stopwords, stem) +// but preserves duplicates for term frequency counting. +type BM25Tokenizer struct{ lang string } + +func (t BM25Tokenizer) Name() string { return "bm25" } +func (t BM25Tokenizer) Type() string { return "string" } +func (t BM25Tokenizer) Tokens(v interface{}) ([]string, error) { + str, ok := v.(string) + if !ok || str == "" { + return []string{}, nil + } + lang := LangBase(t.lang) + tokens := fulltextAnalyzer.Analyze([]byte(str)) + tokens = filterStopwords(lang, tokens) + tokens = filterStemmers(lang, tokens) + // Return all tokens with duplicates preserved (for TF counting). + result := make([]string, 0, len(tokens)) + for _, t := range tokens { + result = append(result, string(t.Term)) + } + return result, nil +} +func (t BM25Tokenizer) Identifier() byte { return IdentBM25 } +func (t BM25Tokenizer) IsSortable() bool { return false } +func (t BM25Tokenizer) IsLossy() bool { return true } + +// TokensWithFrequency tokenizes the input and returns term frequencies and doc length. +func (t BM25Tokenizer) TokensWithFrequency(v interface{}, lang string) (map[string]uint32, uint32, error) { + tok := BM25Tokenizer{lang: lang} + allTokens, err := tok.Tokens(v) + if err != nil { + return nil, 0, err + } + termFreqs := make(map[string]uint32, len(allTokens)) + for _, t := range allTokens { + termFreqs[t]++ + } + return termFreqs, uint32(len(allTokens)), nil +} + // Sha256Tokenizer generates tokens for the sha256 hash part from string data. type Sha256Tokenizer struct{ _ string } diff --git a/tok/tok_test.go b/tok/tok_test.go index 4c95094e577..b9fbc4dd1a5 100644 --- a/tok/tok_test.go +++ b/tok/tok_test.go @@ -652,6 +652,146 @@ func TestNGramTokenizerNonStringInput(t *testing.T) { require.Equal(t, 0, len(tokens2), "Expected empty tokens for nil input") } +func TestBM25Tokenizer(t *testing.T) { + tokenizer, has := GetTokenizer("bm25") + require.True(t, has) + require.NotNil(t, tokenizer) + require.Equal(t, "bm25", tokenizer.Name()) + require.Equal(t, "string", tokenizer.Type()) + require.Equal(t, byte(IdentBM25), tokenizer.Identifier()) + require.True(t, tokenizer.IsLossy()) + require.False(t, tokenizer.IsSortable()) +} + +func TestBM25TokensPreservesDuplicates(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("fox fox fox dog") + require.NoError(t, err) + // "fox" should appear 3 times (duplicates preserved), "dog" once + foxCount := 0 + dogCount := 0 + for _, token := range tokens { + if token == "fox" { + foxCount++ + } + if token == "dog" { + dogCount++ + } + } + require.Equal(t, 3, foxCount, "Expected 3 occurrences of 'fox'") + require.Equal(t, 1, dogCount, "Expected 1 occurrence of 'dog'") +} + +func TestBM25TokensWithFrequency(t *testing.T) { + tok := BM25Tokenizer{} + termFreqs, docLen, err := tok.TokensWithFrequency("the quick brown fox fox fox", "en") + require.NoError(t, err) + // "the" is a stopword and should be removed + _, hasThe := termFreqs["the"] + require.False(t, hasThe, "'the' should be removed as stopword") + // "fox" should have tf=3 + require.Equal(t, uint32(3), termFreqs["fox"]) + // "quick" -> "quick" (stemmed) + require.Contains(t, termFreqs, "quick") + require.Equal(t, uint32(1), termFreqs["quick"]) + // "brown" -> "brown" (stemmed) + require.Contains(t, termFreqs, "brown") + require.Equal(t, uint32(1), termFreqs["brown"]) + // docLen should be total tokens after stopword removal + require.Equal(t, uint32(5), docLen) +} + +func TestBM25TokensEmpty(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) + + termFreqs, docLen, err := tok.TokensWithFrequency("", "en") + require.NoError(t, err) + require.Equal(t, 0, len(termFreqs)) + require.Equal(t, uint32(0), docLen) +} + +func TestBM25TokensSingleWord(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("hello") + require.NoError(t, err) + require.Equal(t, 1, len(tokens)) + require.Equal(t, "hello", tokens[0]) +} + +func TestBM25TokensStemming(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("running jumping swimming") + require.NoError(t, err) + require.Equal(t, 3, len(tokens)) + require.Contains(t, tokens, "run") + require.Contains(t, tokens, "jump") + require.Contains(t, tokens, "swim") +} + +func TestGetBM25QueryTokens(t *testing.T) { + tokens, err := GetBM25QueryTokens([]string{"quick brown fox fox"}, "en") + require.NoError(t, err) + // Query tokens should be deduplicated + require.Equal(t, 3, len(tokens)) + // Each token should be encoded with the BM25 identifier prefix + for _, token := range tokens { + require.Equal(t, byte(IdentBM25), token[0], "Token should start with BM25 identifier") + } +} + +func TestGetBM25QueryTokensEmpty(t *testing.T) { + tokens, err := GetBM25QueryTokens([]string{""}, "en") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) +} + +func TestBM25TokenizerForLang(t *testing.T) { + tokenizer, has := GetTokenizer("bm25") + require.True(t, has) + langTok := GetTokenizerForLang(tokenizer, "de") + bm25Tok, ok := langTok.(BM25Tokenizer) + require.True(t, ok) + // German: "Katzen" -> "katz" (stemmed) + tokens, err := bm25Tok.Tokens("Katzen und Katzen") + require.NoError(t, err) + // "und" is a German stopword + katzCount := 0 + for _, token := range tokens { + if token == "katz" { + katzCount++ + } + } + require.Equal(t, 2, katzCount, "Expected 2 occurrences of stemmed 'katz'") +} + +func TestBM25AllStopwords(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("the a an is") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) + + termFreqs, docLen, err := tok.TokensWithFrequency("the a an is", "en") + require.NoError(t, err) + require.Equal(t, 0, len(termFreqs)) + require.Equal(t, uint32(0), docLen) +} + +func TestGetBM25QueryTokensAllStopwords(t *testing.T) { + tokens, err := GetBM25QueryTokens([]string{"the a an"}, "en") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) +} + +func TestGetBM25QueryTokensWrongArgCount(t *testing.T) { + _, err := GetBM25QueryTokens([]string{}, "en") + require.Error(t, err) + _, err = GetBM25QueryTokens([]string{"a", "b"}, "en") + require.Error(t, err) +} + func BenchmarkTermTokenizer(b *testing.B) { b.Skip() // tmp } diff --git a/tok/tokens.go b/tok/tokens.go index bda9a04e743..f089a3f4344 100644 --- a/tok/tokens.go +++ b/tok/tokens.go @@ -25,6 +25,8 @@ func GetTokenizerForLang(t Tokenizer, lang string) Tokenizer { // We must return a new instance because another goroutine might be calling this // with a different lang. return FullTextTokenizer{lang: lang} + case BM25Tokenizer: + return BM25Tokenizer{lang: lang} case TermTokenizer: return TermTokenizer{lang: lang} case ExactTokenizer: @@ -67,6 +69,29 @@ func GetNGramQueryTokens(funcArgs []string, lang string) ([]string, error) { return BuildNGramQueryTokens(funcArgs[0], NGramTokenizer{lang: lang}) } +// GetBM25QueryTokens tokenizes the query text using the fulltext pipeline, +// deduplicates, and encodes with the BM25 identifier prefix. +func GetBM25QueryTokens(funcArgs []string, lang string) ([]string, error) { + if l := len(funcArgs); l != 1 { + return nil, errors.Errorf("Function requires 1 arguments, but got %d", l) + } + tok := BM25Tokenizer{lang: lang} + allTokens, err := tok.Tokens(funcArgs[0]) + if err != nil { + return nil, err + } + // Deduplicate for query + seen := make(map[string]struct{}, len(allTokens)) + var unique []string + for _, t := range allTokens { + if _, ok := seen[t]; !ok { + seen[t] = struct{}{} + unique = append(unique, encodeToken(t, tok.Identifier())) + } + } + return unique, nil +} + // GetFullTextTokens returns the full-text tokens for the given value. func GetFullTextTokens(funcArgs []string, lang string) ([]string, error) { if l := len(funcArgs); l != 1 { diff --git a/worker/task.go b/worker/task.go index 409ec3f0fc4..1da128c76bc 100644 --- a/worker/task.go +++ b/worker/task.go @@ -7,6 +7,7 @@ package worker import ( "context" + "encoding/binary" "fmt" "math" "sort" @@ -224,6 +225,7 @@ const ( customIndexFn matchFn similarToFn + bm25SearchFn standardFn = 100 ) @@ -266,6 +268,8 @@ func parseFuncTypeHelper(name string) (FuncType, string) { return uidInFn, f case "similar_to": return similarToFn, f + case "bm25": + return bm25SearchFn, f case "anyof", "allof": return customIndexFn, f case "match": @@ -292,6 +296,8 @@ func needsIndex(fnType FuncType, uidList *pb.List) bool { return true case similarToFn: return true + case bm25SearchFn: + return true } return false } @@ -314,7 +320,7 @@ type funcArgs struct { // The function tells us whether we want to fetch value posting lists or uid posting lists. func (srcFn *functionContext) needsValuePostings(typ types.TypeID) (bool, error) { switch srcFn.fnType { - case aggregatorFn, passwordFn, similarToFn: + case aggregatorFn, passwordFn, similarToFn, bm25SearchFn: return true, nil case compareAttrFn: if len(srcFn.tokens) > 0 { @@ -351,11 +357,15 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er attribute.String("srcFn", x.SafeUTF8(fmt.Sprintf("%+v", args.srcFn))))) switch srcFn.fnType { - case notAFunction, aggregatorFn, passwordFn, compareAttrFn, similarToFn: + case notAFunction, aggregatorFn, passwordFn, compareAttrFn, similarToFn, bm25SearchFn: default: return errors.Errorf("Unhandled function in handleValuePostings: %s", srcFn.fname) } + if srcFn.fnType == bm25SearchFn { + return qs.handleBM25Search(ctx, args) + } + if srcFn.fnType == similarToFn { numNeighbors, err := strconv.ParseInt(q.SrcFunc.Args[0], 10, 32) if err != nil { @@ -1219,6 +1229,196 @@ func needsStringFiltering(srcFn *functionContext, langs []string, attr string) b srcFn.fnType == customIndexFn || srcFn.fnType == ngramFn) } +func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error { + q := args.q + attr := q.Attr + + // 1. Parse args: query text, optional k (default 1.2), b (default 0.75). + if len(q.SrcFunc.Args) < 1 { + return errors.Errorf("bm25 requires at least 1 argument (query text)") + } + queryText := q.SrcFunc.Args[0] + k := 1.2 + b := 0.75 + if len(q.SrcFunc.Args) >= 2 { + var err error + k, err = strconv.ParseFloat(q.SrcFunc.Args[1], 64) + if err != nil { + return errors.Errorf("bm25: invalid k parameter: %s", q.SrcFunc.Args[1]) + } + } + if len(q.SrcFunc.Args) >= 3 { + var err error + b, err = strconv.ParseFloat(q.SrcFunc.Args[2], 64) + if err != nil { + return errors.Errorf("bm25: invalid b parameter: %s", q.SrcFunc.Args[2]) + } + } + if math.IsNaN(k) || math.IsInf(k, 0) || k <= 0 { + return errors.Errorf("bm25: k must be a positive finite number, got %v", k) + } + if math.IsNaN(b) || math.IsInf(b, 0) || b < 0 || b > 1 { + return errors.Errorf("bm25: b must be between 0 and 1, got %v", b) + } + + // 2. Tokenize query (deduplicated) using fulltext pipeline. + lang := langForFunc(q.Langs) + queryTokens, err := tok.GetBM25QueryTokens([]string{queryText}, lang) + if err != nil { + return err + } + if len(queryTokens) == 0 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + + // 3. Read corpus stats. + statsKey := x.BM25StatsKey(attr) + statsPl, err := qs.cache.Get(statsKey) + if err != nil { + // No stats means no documents indexed yet. + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + statsVal, err := statsPl.Value(q.ReadTs) + if err != nil || statsVal.Value == nil { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + statsData, ok := statsVal.Value.([]byte) + if !ok || len(statsData) != 16 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + docCount := binary.BigEndian.Uint64(statsData[0:8]) + totalTerms := binary.BigEndian.Uint64(statsData[8:16]) + if docCount == 0 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + if totalTerms == 0 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + avgDL := float64(totalTerms) / float64(docCount) + N := float64(docCount) + + // Build filter set early if used as a filter, for efficient intersection during iteration. + var filterSet map[uint64]struct{} + if q.UidList != nil && len(q.UidList.Uids) > 0 { + filterSet = make(map[uint64]struct{}, len(q.UidList.Uids)) + for _, uid := range q.UidList.Uids { + filterSet[uid] = struct{}{} + } + } + + // 4. For each query token, read the posting list and collect term info. + type termInfo struct { + idf float64 + uidTFs map[uint64]uint32 + } + termInfos := make(map[string]*termInfo) + + for _, token := range queryTokens { + key := x.BM25IndexKey(attr, token) + pl, err := qs.cache.Get(key) + if err != nil { + continue + } + + ti := &termInfo{uidTFs: make(map[uint64]uint32)} + var df float64 + err = pl.Iterate(q.ReadTs, 0, func(p *pb.Posting) error { + df++ + // When used as filter, only collect TF for UIDs in the filter set. + if filterSet != nil { + if _, ok := filterSet[p.Uid]; !ok { + return nil + } + } + tf := uint32(1) + if len(p.Value) >= 4 { + tf = binary.BigEndian.Uint32(p.Value[:4]) + } + ti.uidTFs[p.Uid] = tf + return nil + }) + if err != nil { + continue + } + ti.idf = math.Log1p((N - df + 0.5) / (df + 0.5)) + termInfos[token] = ti + } + + // 5. Read doc lengths for all UIDs seen. + allUids := make(map[uint64]struct{}) + for _, ti := range termInfos { + for uid := range ti.uidTFs { + allUids[uid] = struct{}{} + } + } + + docLens := make(map[uint64]uint32) + dlKey := x.BM25DocLenKey(attr) + dlPl, err := qs.cache.Get(dlKey) + if err == nil { + remaining := len(allUids) + _ = dlPl.Iterate(q.ReadTs, 0, func(p *pb.Posting) error { + if remaining == 0 { + return posting.ErrStopIteration + } + if _, needed := allUids[p.Uid]; needed { + dl := uint32(1) + if len(p.Value) >= 4 { + dl = binary.BigEndian.Uint32(p.Value[:4]) + } + docLens[p.Uid] = dl + remaining-- + } + return nil + }) + } + + // 6. Compute final BM25 scores. + scores := make(map[uint64]float64) + for _, ti := range termInfos { + for uid, tf := range ti.uidTFs { + dl := float64(1) + if v, ok := docLens[uid]; ok { + dl = float64(v) + } + tfFloat := float64(tf) + score := ti.idf * (k + 1) * tfFloat / (k*(1-b+b*dl/avgDL) + tfFloat) + scores[uid] += score + } + } + + // 7. Sort by score descending. + type uidScore struct { + uid uint64 + score float64 + } + results := make([]uidScore, 0, len(scores)) + for uid, score := range scores { + results = append(results, uidScore{uid: uid, score: score}) + } + sort.Slice(results, func(i, j int) bool { + if results[i].score != results[j].score { + return results[i].score > results[j].score + } + return results[i].uid < results[j].uid + }) + + // Build output UIDs. + uids := make([]uint64, len(results)) + for i, r := range results { + uids[i] = r.uid + } + + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) + return nil +} + func (qs *queryState) handleCompareScalarFunction(ctx context.Context, arg funcArgs) error { attr := arg.q.Attr if ok := schema.State().HasCount(ctx, attr); !ok { @@ -2167,6 +2367,18 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) { return nil, err } checkRoot(q, fc) + case bm25SearchFn: + // bm25(pred, "query text") or bm25(pred, "query text", "k", "b") + if len(q.SrcFunc.Args) < 1 || len(q.SrcFunc.Args) > 3 { + return nil, errors.Errorf("Function 'bm25' requires 1-3 arguments (query [, k, b]), but got %d", + len(q.SrcFunc.Args)) + } + required, found := verifyStringIndex(ctx, attr, fnType) + if !found { + return nil, errors.Errorf("Attribute %s is not indexed with type %s", x.ParseAttr(attr), + required) + } + checkRoot(q, fc) case similarToFn: // similar_to accepts 2 mandatory args: k, vector_or_uid followed by optional key:value pairs // Example: similar_to(vpred, 3, $vec, ef: 64, distance_threshold: 0.5) diff --git a/worker/tokens.go b/worker/tokens.go index 2740d29f447..b8c85a22816 100644 --- a/worker/tokens.go +++ b/worker/tokens.go @@ -25,6 +25,8 @@ func verifyStringIndex(ctx context.Context, attr string, funcType FuncType) (str requiredTokenizer = tok.NGramTokenizer{} case fullTextSearchFn: requiredTokenizer = tok.FullTextTokenizer{} + case bm25SearchFn: + requiredTokenizer = tok.BM25Tokenizer{} case matchFn: requiredTokenizer = tok.TrigramTokenizer{} default: @@ -65,6 +67,9 @@ func getStringTokens(funcArgs []string, lang string, funcType FuncType, query bo if funcType == fullTextSearchFn { return tok.GetFullTextTokens(funcArgs, lang) } + if funcType == bm25SearchFn { + return tok.GetBM25QueryTokens(funcArgs, lang) + } if funcType == ngramFn { if query { return tok.GetNGramQueryTokens(funcArgs, lang) diff --git a/x/keys.go b/x/keys.go index 94112d07c03..23196fd89c9 100644 --- a/x/keys.go +++ b/x/keys.go @@ -291,6 +291,25 @@ func CountKey(attr string, count uint32, reverse bool) []byte { return buf } +// BM25Prefix is the prefix used for BM25 index keys to prevent collision +// with regular fulltext index tokens. +const BM25Prefix = "\x00_bm25_" + +// BM25IndexKey generates an index key for a BM25 term posting list. +func BM25IndexKey(attr string, token string) []byte { + return IndexKey(attr, BM25Prefix+token) +} + +// BM25DocLenKey generates the key for the BM25 document length posting list. +func BM25DocLenKey(attr string) []byte { + return IndexKey(attr, BM25Prefix+"__doclen__") +} + +// BM25StatsKey generates the key for BM25 corpus statistics. +func BM25StatsKey(attr string) []byte { + return IndexKey(attr, BM25Prefix+"__stats__") +} + // ParsedKey represents a key that has been parsed into its multiple attributes. type ParsedKey struct { Attr string From 937da2ec60f39efea21eb096fa77dc34201f81d0 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 4 Mar 2026 17:18:15 -0500 Subject: [PATCH 02/12] fix(bm25): store TF/doclen in facets and fix query pipeline integration Three critical bugs fixed: 1. REF postings lose Value during rollup: The posting list encode/rollup cycle strips the Value field from REF postings without facets (list.go:1630). BM25 term frequencies and doc lengths were stored in Value and lost. Fix: Store TF and doclen as facets on REF postings, which are preserved. 2. Missing function validation: query/query.go has a separate isValidFuncName check from dql/parser.go. "bm25" was only added to the parser, causing "Invalid function name: bm25" at query time. 3. Unsorted UIDs break query pipeline: BM25 returned UIDs sorted by score, but the query pipeline (algo.MergeSorted, child predicate fetching) requires UID-ascending order. Fix: Sort UIDs ascending in UidMatrix, apply first/offset pagination on score-sorted results before UID sorting. Co-Authored-By: Claude Opus 4.6 --- posting/index.go | 22 +++++++++++----------- query/query.go | 2 +- query/query_bm25_test.go | 27 ++++++++++++++++++--------- worker/task.go | 30 +++++++++++++++++++++++++----- 4 files changed, 55 insertions(+), 26 deletions(-) diff --git a/posting/index.go b/posting/index.go index 88c0e5920a9..a24f0bac2e6 100644 --- a/posting/index.go +++ b/posting/index.go @@ -28,6 +28,7 @@ import ( "github.com/dgraph-io/badger/v4" "github.com/dgraph-io/badger/v4/options" bpb "github.com/dgraph-io/badger/v4/pb" + "github.com/dgraph-io/dgo/v250/protos/api" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" "github.com/dgraph-io/dgraph/v25/tok" @@ -304,15 +305,15 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn if err != nil { return err } - // Store uid in the posting list. The TF is encoded in the Value field. + // Store uid in the posting list. TF is stored as a facet so it survives + // the rollup cycle (REF postings without facets lose their Value field). tfBuf := make([]byte, 4) binary.BigEndian.PutUint32(tfBuf, tf) edge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Value: tfBuf, - ValueType: pb.Posting_INT, - Op: pb.DirectedEdge_SET, + ValueId: uid, + Attr: attr, + Op: pb.DirectedEdge_SET, + Facets: []*api.Facet{{Key: "tf", Value: tfBuf, ValType: api.Facet_INT}}, } if err := plist.addMutation(ctx, txn, edge); err != nil { return err @@ -328,11 +329,10 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn dlBuf := make([]byte, 4) binary.BigEndian.PutUint32(dlBuf, docLen) dlEdge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Value: dlBuf, - ValueType: pb.Posting_INT, - Op: pb.DirectedEdge_SET, + ValueId: uid, + Attr: attr, + Op: pb.DirectedEdge_SET, + Facets: []*api.Facet{{Key: "dl", Value: dlBuf, ValType: api.Facet_INT}}, } if err := dlPlist.addMutation(ctx, txn, dlEdge); err != nil { return err diff --git a/query/query.go b/query/query.go index 6926e2ac6ed..3025033e1e0 100644 --- a/query/query.go +++ b/query/query.go @@ -2751,7 +2751,7 @@ func isValidArg(a string) bool { func isValidFuncName(f string) bool { switch f { case "anyofterms", "allofterms", "val", "regexp", "anyoftext", "alloftext", "ngram", - "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25": return true } return isInequalityFn(f) || types.IsGeoFunc(f) diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index f0a3a0c16a9..dcceece7428 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -10,7 +10,6 @@ package query import ( "context" - "strings" "testing" "github.com/stretchr/testify/require" @@ -32,6 +31,8 @@ func TestBM25Basic(t *testing.T) { } func TestBM25Ordering(t *testing.T) { + // BM25 returns all matching documents. Use first:1 to verify the highest-scored + // document is "fox fox fox" (tf=3, short doc). query := ` { me(func: bm25(description_bm25, "fox")) { @@ -41,14 +42,22 @@ func TestBM25Ordering(t *testing.T) { } ` js := processQueryNoErr(t, query) - // Document 503 has "fox fox fox" (tf=3, short doc) so should rank highest. - // Verify it appears before other fox-containing documents in the output. - foxFoxFoxIdx := strings.Index(js, "fox fox fox") - quickBrownIdx := strings.Index(js, "quick brown fox jumps") - require.Greater(t, foxFoxFoxIdx, -1, "should contain 'fox fox fox'") - require.Greater(t, quickBrownIdx, -1, "should contain 'quick brown fox jumps'") - require.Less(t, foxFoxFoxIdx, quickBrownIdx, - "'fox fox fox' (higher tf, shorter doc) should rank before 'quick brown fox jumps'") + // Should contain all fox-mentioning documents. + require.Contains(t, js, "fox fox fox") + require.Contains(t, js, "quick brown fox jumps") + + // first:1 should return the top-ranked document. + topQuery := ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + description_bm25 + } + } + ` + topJs := processQueryNoErr(t, topQuery) + require.Contains(t, topJs, "fox fox fox", + "top-1 BM25 result for 'fox' should be 'fox fox fox' (highest tf, shortest doc)") } func TestBM25WithParams(t *testing.T) { diff --git a/worker/task.go b/worker/task.go index 1da128c76bc..2fbd65acca4 100644 --- a/worker/task.go +++ b/worker/task.go @@ -1337,8 +1337,11 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } tf := uint32(1) - if len(p.Value) >= 4 { - tf = binary.BigEndian.Uint32(p.Value[:4]) + for _, f := range p.Facets { + if f.Key == "tf" && len(f.Value) >= 4 { + tf = binary.BigEndian.Uint32(f.Value[:4]) + break + } } ti.uidTFs[p.Uid] = tf return nil @@ -1369,8 +1372,11 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } if _, needed := allUids[p.Uid]; needed { dl := uint32(1) - if len(p.Value) >= 4 { - dl = binary.BigEndian.Uint32(p.Value[:4]) + for _, f := range p.Facets { + if f.Key == "dl" && len(f.Value) >= 4 { + dl = binary.BigEndian.Uint32(f.Value[:4]) + break + } } docLens[p.Uid] = dl remaining-- @@ -1402,6 +1408,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error for uid, score := range scores { results = append(results, uidScore{uid: uid, score: score}) } + // Sort by score descending for ordering, then collect UIDs. sort.Slice(results, func(i, j int) bool { if results[i].score != results[j].score { return results[i].score > results[j].score @@ -1409,11 +1416,24 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error return results[i].uid < results[j].uid }) - // Build output UIDs. + // Apply first/offset pagination on score-sorted results before returning UIDs. + if q.First > 0 || q.Offset > 0 { + offset := int(q.Offset) + if offset > len(results) { + offset = len(results) + } + results = results[offset:] + if q.First > 0 && int(q.First) < len(results) { + results = results[:int(q.First)] + } + } + + // Build output UIDs sorted by UID (ascending) as required by the query pipeline. uids := make([]uint64, len(results)) for i, r := range results { uids[i] = r.uid } + sort.Slice(uids, func(i, j int) bool { return uids[i] < uids[j] }) args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) return nil From a61b2268df21a685bb2e9f60ae73e3b0d4fee51f Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 4 Mar 2026 23:04:08 -0500 Subject: [PATCH 03/12] perf(bm25): replace facet storage with compact direct Badger KV encoding Replace the facet-based BM25 storage (~40-50 bytes/posting) with compact varint-encoded binary blobs stored as direct Badger KV entries (~4-6 bytes/posting, ~10x reduction). Add bm25_score pseudo-predicate for variable-based score ordering following the similar_to pattern. - Add posting/bm25enc package for compact binary encode/decode - Rewrite write path in posting/index.go for direct Badger KV - Add bm25Writes buffer to LocalCache with read-your-own-writes - Flush BM25 blobs in CommitToDisk with BitBM25Data UserMeta - Rewrite read path in worker/task.go with direct blob decoding - Add bm25_score pseudo-predicate in query/query.go - Add score ordering integration tests Co-Authored-By: Claude Opus 4.6 --- posting/bm25enc/bm25enc.go | 147 ++++++++++++++++++++++++++++++++ posting/bm25enc/bm25enc_test.go | 132 ++++++++++++++++++++++++++++ posting/index.go | 121 +++++++------------------- posting/list.go | 2 + posting/lists.go | 56 ++++++++++++ posting/mvcc.go | 15 ++++ query/query.go | 64 ++++++++++++++ query/query_bm25_test.go | 79 +++++++++++++++++ worker/task.go | 109 +++++++++-------------- 9 files changed, 562 insertions(+), 163 deletions(-) create mode 100644 posting/bm25enc/bm25enc.go create mode 100644 posting/bm25enc/bm25enc_test.go diff --git a/posting/bm25enc/bm25enc.go b/posting/bm25enc/bm25enc.go new file mode 100644 index 00000000000..8da82b299dd --- /dev/null +++ b/posting/bm25enc/bm25enc.go @@ -0,0 +1,147 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Package bm25enc provides compact binary encoding for BM25 index data. +// +// Two types of lists share the same format: +// - Term posting lists: (UID, term-frequency) pairs +// - Document length lists: (UID, doc-length) pairs +// +// Binary format: +// +// Header: +// [4 bytes] uint32 big-endian: entry count +// Entries (sorted ascending by UID): +// [varint] UID delta from previous (first entry is absolute) +// [varint] value (TF or doclen) +package bm25enc + +import ( + "encoding/binary" + "sort" +) + +// Entry represents a single (UID, Value) pair in a BM25 posting list. +type Entry struct { + UID uint64 + Value uint32 +} + +// Encode encodes a sorted slice of entries into the compact binary format. +// Entries must be sorted by UID ascending. Returns nil for empty input. +func Encode(entries []Entry) []byte { + if len(entries) == 0 { + return nil + } + + // Pre-allocate: 4 header + ~6 bytes per entry is a reasonable estimate. + buf := make([]byte, 4, 4+len(entries)*6) + binary.BigEndian.PutUint32(buf, uint32(len(entries))) + + var tmp [binary.MaxVarintLen64]byte + var prevUID uint64 + for _, e := range entries { + delta := e.UID - prevUID + n := binary.PutUvarint(tmp[:], delta) + buf = append(buf, tmp[:n]...) + n = binary.PutUvarint(tmp[:], uint64(e.Value)) + buf = append(buf, tmp[:n]...) + prevUID = e.UID + } + return buf +} + +// Decode decodes the binary format into a sorted slice of entries. +// Returns nil for nil/empty input. +func Decode(data []byte) []Entry { + if len(data) < 4 { + return nil + } + count := binary.BigEndian.Uint32(data[:4]) + if count == 0 { + return nil + } + + entries := make([]Entry, 0, count) + pos := 4 + var prevUID uint64 + for i := uint32(0); i < count; i++ { + delta, n := binary.Uvarint(data[pos:]) + if n <= 0 { + break + } + pos += n + + val, n := binary.Uvarint(data[pos:]) + if n <= 0 { + break + } + pos += n + + uid := prevUID + delta + entries = append(entries, Entry{UID: uid, Value: uint32(val)}) + prevUID = uid + } + return entries +} + +// Upsert inserts or updates the entry for uid in a sorted entries slice. +// Returns the new sorted slice. +func Upsert(entries []Entry, uid uint64, value uint32) []Entry { + i := sort.Search(len(entries), func(i int) bool { return entries[i].UID >= uid }) + if i < len(entries) && entries[i].UID == uid { + entries[i].Value = value + return entries + } + // Insert at position i. + entries = append(entries, Entry{}) + copy(entries[i+1:], entries[i:]) + entries[i] = Entry{UID: uid, Value: value} + return entries +} + +// Remove removes the entry for uid from a sorted entries slice. +// Returns the new slice (may be shorter). +func Remove(entries []Entry, uid uint64) []Entry { + i := sort.Search(len(entries), func(i int) bool { return entries[i].UID >= uid }) + if i < len(entries) && entries[i].UID == uid { + return append(entries[:i], entries[i+1:]...) + } + return entries +} + +// Search returns the value for uid using binary search, and whether it was found. +func Search(entries []Entry, uid uint64) (uint32, bool) { + i := sort.Search(len(entries), func(i int) bool { return entries[i].UID >= uid }) + if i < len(entries) && entries[i].UID == uid { + return entries[i].Value, true + } + return 0, false +} + +// UIDs extracts just the UIDs from entries as a uint64 slice. +func UIDs(entries []Entry) []uint64 { + uids := make([]uint64, len(entries)) + for i, e := range entries { + uids[i] = e.UID + } + return uids +} + +// EncodeStats encodes BM25 corpus statistics (docCount, totalTerms) as 16 bytes. +func EncodeStats(docCount, totalTerms uint64) []byte { + buf := make([]byte, 16) + binary.BigEndian.PutUint64(buf[0:8], docCount) + binary.BigEndian.PutUint64(buf[8:16], totalTerms) + return buf +} + +// DecodeStats decodes BM25 corpus statistics. Returns (0,0) for invalid input. +func DecodeStats(data []byte) (docCount, totalTerms uint64) { + if len(data) != 16 { + return 0, 0 + } + return binary.BigEndian.Uint64(data[0:8]), binary.BigEndian.Uint64(data[8:16]) +} diff --git a/posting/bm25enc/bm25enc_test.go b/posting/bm25enc/bm25enc_test.go new file mode 100644 index 00000000000..1969e472ed2 --- /dev/null +++ b/posting/bm25enc/bm25enc_test.go @@ -0,0 +1,132 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package bm25enc + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRoundtrip(t *testing.T) { + entries := []Entry{ + {UID: 1, Value: 3}, + {UID: 5, Value: 1}, + {UID: 100, Value: 7}, + {UID: 200, Value: 2}, + } + data := Encode(entries) + got := Decode(data) + require.Equal(t, entries, got) +} + +func TestRoundtripEmpty(t *testing.T) { + require.Nil(t, Encode(nil)) + require.Nil(t, Encode([]Entry{})) + require.Nil(t, Decode(nil)) + require.Nil(t, Decode([]byte{})) + require.Nil(t, Decode([]byte{0, 0, 0, 0})) // count=0 +} + +func TestRoundtripSingle(t *testing.T) { + entries := []Entry{{UID: 42, Value: 10}} + got := Decode(Encode(entries)) + require.Equal(t, entries, got) +} + +func TestRoundtripLargeUIDs(t *testing.T) { + entries := []Entry{ + {UID: 1<<40 + 1, Value: 1}, + {UID: 1<<40 + 1000, Value: 5}, + {UID: 1<<50 + 999, Value: 99}, + } + got := Decode(Encode(entries)) + require.Equal(t, entries, got) +} + +func TestUpsertNew(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}} + entries = Upsert(entries, 3, 7) + require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 3, Value: 7}, {UID: 5, Value: 1}}, entries) +} + +func TestUpsertExisting(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}} + entries = Upsert(entries, 5, 99) + require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 99}}, entries) +} + +func TestUpsertEmpty(t *testing.T) { + var entries []Entry + entries = Upsert(entries, 10, 5) + require.Equal(t, []Entry{{UID: 10, Value: 5}}, entries) +} + +func TestRemove(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}, {UID: 10, Value: 2}} + entries = Remove(entries, 5) + require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 10, Value: 2}}, entries) +} + +func TestRemoveNotFound(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}} + entries = Remove(entries, 99) + require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}}, entries) +} + +func TestSearch(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}, {UID: 100, Value: 7}} + v, ok := Search(entries, 5) + require.True(t, ok) + require.Equal(t, uint32(1), v) + + _, ok = Search(entries, 50) + require.False(t, ok) +} + +func TestUIDs(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}, {UID: 100, Value: 7}} + require.Equal(t, []uint64{1, 5, 100}, UIDs(entries)) +} + +func TestStatsRoundtrip(t *testing.T) { + data := EncodeStats(12345, 98765) + dc, tt := DecodeStats(data) + require.Equal(t, uint64(12345), dc) + require.Equal(t, uint64(98765), tt) +} + +func TestStatsInvalid(t *testing.T) { + dc, tt := DecodeStats(nil) + require.Zero(t, dc) + require.Zero(t, tt) + dc, tt = DecodeStats([]byte{1, 2, 3}) + require.Zero(t, dc) + require.Zero(t, tt) +} + +func BenchmarkEncode(b *testing.B) { + entries := make([]Entry, 10000) + for i := range entries { + entries[i] = Entry{UID: uint64(i*3 + 1), Value: uint32(i % 100)} + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + Encode(entries) + } +} + +func BenchmarkDecode(b *testing.B) { + entries := make([]Entry, 10000) + for i := range entries { + entries[i] = Entry{UID: uint64(i*3 + 1), Value: uint32(i % 100)} + } + data := Encode(entries) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Decode(data) + } +} diff --git a/posting/index.go b/posting/index.go index a24f0bac2e6..826355a3633 100644 --- a/posting/index.go +++ b/posting/index.go @@ -28,7 +28,7 @@ import ( "github.com/dgraph-io/badger/v4" "github.com/dgraph-io/badger/v4/options" bpb "github.com/dgraph-io/badger/v4/pb" - "github.com/dgraph-io/dgo/v250/protos/api" + "github.com/dgraph-io/dgraph/v25/posting/bm25enc" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" "github.com/dgraph-io/dgraph/v25/tok" @@ -232,7 +232,8 @@ func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, tok } // addBM25IndexMutations handles index mutations for the BM25 tokenizer. -// It stores term frequencies, document lengths, and corpus statistics. +// It stores term frequencies, document lengths, and corpus statistics as direct +// Badger KV entries using compact varint encoding, bypassing posting lists. func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationInfo) error { attr := info.edge.Attr uid := info.edge.Entity @@ -260,107 +261,53 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn } if info.op == pb.DirectedEdge_DEL { - // For DELETE: remove uid from all BM25 term posting lists, doc length list, - // and decrement corpus stats. + // For DELETE: remove uid from all BM25 term posting lists and doc length list. for term := range termFreqs { encodedTerm := string([]byte{tok.IdentBM25}) + term key := x.BM25IndexKey(attr, encodedTerm) - plist, err := txn.cache.GetFromDelta(key) - if err != nil { - return err - } - edge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Op: pb.DirectedEdge_DEL, - } - if err := plist.addMutation(ctx, txn, edge); err != nil { - return err - } + blob := txn.cache.ReadBM25Blob(key) + entries := bm25enc.Decode(blob) + entries = bm25enc.Remove(entries, uid) + txn.cache.WriteBM25Blob(key, bm25enc.Encode(entries)) } // Remove doc length entry. dlKey := x.BM25DocLenKey(attr) - dlPlist, err := txn.cache.GetFromDelta(dlKey) - if err != nil { - return err - } - dlEdge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Op: pb.DirectedEdge_DEL, - } - if err := dlPlist.addMutation(ctx, txn, dlEdge); err != nil { - return err - } + blob := txn.cache.ReadBM25Blob(dlKey) + entries := bm25enc.Decode(blob) + entries = bm25enc.Remove(entries, uid) + txn.cache.WriteBM25Blob(dlKey, bm25enc.Encode(entries)) // Update corpus stats: decrement doc count and total terms. - return txn.updateBM25Stats(ctx, attr, -1, -int64(docLen)) + return txn.updateBM25Stats(attr, -1, -int64(docLen)) } - // For SET: store term frequencies, doc length, and update corpus stats. + // For SET: store term frequencies and doc length. for term, tf := range termFreqs { encodedTerm := string([]byte{tok.IdentBM25}) + term key := x.BM25IndexKey(attr, encodedTerm) - plist, err := txn.cache.GetFromDelta(key) - if err != nil { - return err - } - // Store uid in the posting list. TF is stored as a facet so it survives - // the rollup cycle (REF postings without facets lose their Value field). - tfBuf := make([]byte, 4) - binary.BigEndian.PutUint32(tfBuf, tf) - edge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Op: pb.DirectedEdge_SET, - Facets: []*api.Facet{{Key: "tf", Value: tfBuf, ValType: api.Facet_INT}}, - } - if err := plist.addMutation(ctx, txn, edge); err != nil { - return err - } + blob := txn.cache.ReadBM25Blob(key) + entries := bm25enc.Decode(blob) + entries = bm25enc.Upsert(entries, uid, tf) + txn.cache.WriteBM25Blob(key, bm25enc.Encode(entries)) } // Store document length. dlKey := x.BM25DocLenKey(attr) - dlPlist, err := txn.cache.GetFromDelta(dlKey) - if err != nil { - return err - } - dlBuf := make([]byte, 4) - binary.BigEndian.PutUint32(dlBuf, docLen) - dlEdge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Op: pb.DirectedEdge_SET, - Facets: []*api.Facet{{Key: "dl", Value: dlBuf, ValType: api.Facet_INT}}, - } - if err := dlPlist.addMutation(ctx, txn, dlEdge); err != nil { - return err - } + blob := txn.cache.ReadBM25Blob(dlKey) + entries := bm25enc.Decode(blob) + entries = bm25enc.Upsert(entries, uid, docLen) + txn.cache.WriteBM25Blob(dlKey, bm25enc.Encode(entries)) // Update corpus stats: increment doc count by 1 and total terms by docLen. - return txn.updateBM25Stats(ctx, attr, 1, int64(docLen)) + return txn.updateBM25Stats(attr, 1, int64(docLen)) } // updateBM25Stats reads the current corpus statistics for a BM25-indexed attribute, -// applies the given deltas, and writes back. -func (txn *Txn) updateBM25Stats(ctx context.Context, attr string, docCountDelta int64, totalTermsDelta int64) error { +// applies the given deltas, and writes back as a direct Badger KV entry. +func (txn *Txn) updateBM25Stats(attr string, docCountDelta int64, totalTermsDelta int64) error { statsKey := x.BM25StatsKey(attr) - plist, err := txn.cache.GetFromDelta(statsKey) - if err != nil { - return err - } - - // Read existing stats from posting with uid=1. - var docCount, totalTerms uint64 - val, err := plist.Value(txn.StartTs) - if err == nil && val.Value != nil { - data, ok := val.Value.([]byte) - if ok && len(data) == 16 { - docCount = binary.BigEndian.Uint64(data[0:8]) - totalTerms = binary.BigEndian.Uint64(data[8:16]) - } - } + blob := txn.cache.ReadBM25Blob(statsKey) + docCount, totalTerms := bm25enc.DecodeStats(blob) // Apply deltas. if docCountDelta >= 0 { @@ -384,18 +331,8 @@ func (txn *Txn) updateBM25Stats(ctx context.Context, attr string, docCountDelta } } - // Write back stats. - statsBuf := make([]byte, 16) - binary.BigEndian.PutUint64(statsBuf[0:8], docCount) - binary.BigEndian.PutUint64(statsBuf[8:16], totalTerms) - edge := &pb.DirectedEdge{ - Entity: 1, - Attr: attr, - Value: statsBuf, - ValueType: pb.Posting_ValType(0), - Op: pb.DirectedEdge_SET, - } - return plist.addMutation(ctx, txn, edge) + txn.cache.WriteBM25Blob(statsKey, bm25enc.EncodeStats(docCount, totalTerms)) + return nil } // countParams is sent to updateCount function. It is used to update the count index. diff --git a/posting/list.go b/posting/list.go index 1c0c7a0fc55..5420a69a157 100644 --- a/posting/list.go +++ b/posting/list.go @@ -60,6 +60,8 @@ const ( BitCompletePosting byte = 0x08 // BitEmptyPosting signals that the value stores an empty posting list. BitEmptyPosting byte = 0x10 + // BitBM25Data signals that the value stores BM25 index data (direct KV, not a posting list). + BitBM25Data byte = 0x20 ) // List stores the in-memory representation of a posting list. diff --git a/posting/lists.go b/posting/lists.go index a4bc4fb355b..0bd9848de23 100644 --- a/posting/lists.go +++ b/posting/lists.go @@ -76,6 +76,10 @@ type LocalCache struct { // plists are posting lists in memory. They can be discarded to reclaim space. plists map[string]*List + + // bm25Writes buffers BM25 direct KV writes (key → encoded blob). + // These bypass the posting list infrastructure entirely. + bm25Writes map[string][]byte } // struct to implement LocalCache interface from vector-indexer @@ -135,6 +139,7 @@ func NewLocalCache(startTs uint64) *LocalCache { deltas: make(map[string][]byte), plists: make(map[string]*List), maxVersions: make(map[string]uint64), + bm25Writes: make(map[string][]byte), } } @@ -144,6 +149,57 @@ func NoCache(startTs uint64) *LocalCache { return &LocalCache{startTs: startTs} } +// ReadBM25Blob returns the BM25 blob for the given key. +// It checks the in-memory buffer first (read-your-own-writes), +// then falls back to reading from pstore at startTs. +func (lc *LocalCache) ReadBM25Blob(key []byte) []byte { + lc.RLock() + if blob, ok := lc.bm25Writes[string(key)]; ok { + lc.RUnlock() + return blob + } + lc.RUnlock() + + // Fall back to Badger. + txn := pstore.NewTransactionAt(lc.startTs, false) + defer txn.Discard() + item, err := txn.Get(key) + if err != nil { + return nil + } + val, err := item.ValueCopy(nil) + if err != nil { + return nil + } + return val +} + +// WriteBM25Blob buffers a BM25 blob write for the given key. +func (lc *LocalCache) WriteBM25Blob(key []byte, blob []byte) { + lc.Lock() + defer lc.Unlock() + if lc.bm25Writes == nil { + lc.bm25Writes = make(map[string][]byte) + } + lc.bm25Writes[string(key)] = blob +} + +// ReadBM25BlobAt reads a BM25 blob from pstore at the given read timestamp. +// This is used by the query read path (worker/task.go). +func ReadBM25BlobAt(key []byte, readTs uint64) []byte { + txn := pstore.NewTransactionAt(readTs, false) + defer txn.Discard() + item, err := txn.Get(key) + if err != nil { + return nil + } + val, err := item.ValueCopy(nil) + if err != nil { + return nil + } + return val +} + func (lc *LocalCache) UpdateCommitTs(commitTs uint64) { lc.Lock() defer lc.Unlock() diff --git a/posting/mvcc.go b/posting/mvcc.go index 81c5e375553..3b9510ef6bb 100644 --- a/posting/mvcc.go +++ b/posting/mvcc.go @@ -318,6 +318,21 @@ func (txn *Txn) CommitToDisk(writer *TxnWriter, commitTs uint64) error { return err } } + + // Flush BM25 direct KV writes. These are complete blobs (not deltas) + // and don't need rollup. + for key, blob := range cache.bm25Writes { + if err := writer.update(commitTs, func(btxn *badger.Txn) error { + return btxn.SetEntry(&badger.Entry{ + Key: []byte(key), + Value: blob, + UserMeta: BitBM25Data, + }) + }); err != nil { + return err + } + } + return nil } diff --git a/query/query.go b/query/query.go index 3025033e1e0..e241d18946b 100644 --- a/query/query.go +++ b/query/query.go @@ -7,6 +7,7 @@ package query import ( "context" + "encoding/binary" "fmt" "math" "sort" @@ -373,6 +374,19 @@ func getValue(tv *pb.TaskValue) (types.Val, error) { return val, nil } +func valToTaskValue(v types.Val) *pb.TaskValue { + data := types.ValueForType(types.BinaryID) + res := &pb.TaskValue{ValType: v.Tid.Enum(), Val: x.Nilbyte} + if v.Value == nil { + return res + } + if err := types.Marshal(v, &data); err != nil { + return res + } + res.Val = data.Value.([]byte) + return res +} + var ( // ErrEmptyVal is returned when a value is empty. ErrEmptyVal = errors.New("Query: harmless error, e.g. task.Val is nil") @@ -1369,6 +1383,9 @@ func (sg *SubGraph) valueVarAggregation(doneVars map[string]varValue, path []*Su case sg.Attr == "uid" && sg.Params.DoCount: // This is the count(uid) case. // We will do the computation later while constructing the result. + case sg.Attr == "bm25_score": + // bm25_score is a pseudo-predicate handled inline during children processing. + // Its valueMatrix is already populated. Nothing to aggregate. default: return errors.Errorf("Unhandled pb.node <%v> with parent <%v>", sg.Attr, parent.Attr) } @@ -2173,6 +2190,7 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { rch <- nil return } + var err error switch { case parent == nil && sg.SrcFunc != nil && sg.SrcFunc.Name == "uid": @@ -2275,6 +2293,30 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { sg.List = result.List sg.vectorMetrics = result.VectorMetrics + // If this is a BM25 root function, extract scores from ValueMatrix + // and store them in ParentVars for bm25_score pseudo-predicate children. + if sg.SrcFunc != nil && sg.SrcFunc.Name == "bm25" && len(result.UidMatrix) > 0 && + len(result.ValueMatrix) > 0 { + bm25Scores := types.NewShardedMap() + uids := result.UidMatrix[0].GetUids() + for i, uid := range uids { + if i < len(result.ValueMatrix) && len(result.ValueMatrix[i].Values) > 0 { + tv := result.ValueMatrix[i].Values[0] + if len(tv.Val) == 8 { + score := math.Float64frombits(binary.LittleEndian.Uint64(tv.Val)) + bm25Scores.Set(uid, types.Val{ + Tid: types.FloatID, + Value: score, + }) + } + } + } + if sg.Params.ParentVars == nil { + sg.Params.ParentVars = make(map[string]varValue) + } + sg.Params.ParentVars["__bm25_scores__"] = varValue{Vals: bm25Scores} + } + if sg.Params.DoCount { if len(sg.Filters) == 0 { // If there is a filter, we need to do more work to get the actual count. @@ -2452,6 +2494,28 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { } child.SrcUIDs = sg.DestUIDs // Make the connection. + + // Handle bm25_score pseudo-predicate: populate valueMatrix from parent's + // BM25 scores. Mark IsInternal so populateUidValVar case 4 (value variable) + // fires instead of case 3 (UID variable). + if child.Attr == "bm25_score" { + if bm25Var, ok := child.Params.ParentVars["__bm25_scores__"]; ok && bm25Var.Vals != nil { + child.valueMatrix = make([]*pb.ValueList, len(child.SrcUIDs.GetUids())) + for j, uid := range child.SrcUIDs.GetUids() { + if val, okv := bm25Var.Vals.Get(uid); okv { + child.valueMatrix[j] = &pb.ValueList{ + Values: []*pb.TaskValue{valToTaskValue(val)}, + } + } else { + child.valueMatrix[j] = &pb.ValueList{} + } + } + } + child.DestUIDs = &pb.List{} + child.Params.IsInternal = true + continue + } + if child.IsInternal() { // We dont have to execute these nodes. continue diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index dcceece7428..cdb235be36f 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -221,3 +221,82 @@ func TestBM25Pagination(t *testing.T) { // Doc 503 "fox fox fox" should be the top result. require.Contains(t, js, "fox fox fox") } + +func TestBM25ScoreOrdering(t *testing.T) { + // Use the bm25_score pseudo-predicate with var block to order results by score. + query := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score), first: 1) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + // "fox fox fox" (doc 503) has the highest BM25 score (tf=3, shortest doc). + require.Contains(t, js, "fox fox fox") +} + +func TestBM25ScoreOrderingMultiTerm(t *testing.T) { + // Multi-term query with score ordering: "quick lazy" should rank doc 501 highest + // since it contains both terms. + query := ` + { + var(func: bm25(description_bm25, "quick lazy")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score), first: 1) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "quick brown fox jumps over the lazy dog") +} + +func TestBM25ScoreOrderingAllResults(t *testing.T) { + // Verify all results are returned in score-descending order via val(score). + query := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + // All fox-containing docs should appear. + require.Contains(t, js, "fox fox fox") + require.Contains(t, js, "quick brown fox jumps") + // Score values should be present. + require.Contains(t, js, "val(score)") +} + +func TestBM25ScoreWithPagination(t *testing.T) { + // Use offset with score ordering. + query := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score), first: 1, offset: 1) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should return the second-highest scored document (not "fox fox fox"). + require.NotContains(t, js, "fox fox fox") + require.Contains(t, js, "fox") +} diff --git a/worker/task.go b/worker/task.go index 2fbd65acca4..fbc3189a42b 100644 --- a/worker/task.go +++ b/worker/task.go @@ -30,6 +30,7 @@ import ( "github.com/dgraph-io/dgraph/v25/algo" "github.com/dgraph-io/dgraph/v25/conn" "github.com/dgraph-io/dgraph/v25/posting" + "github.com/dgraph-io/dgraph/v25/posting/bm25enc" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" ctask "github.com/dgraph-io/dgraph/v25/task" @@ -1272,31 +1273,11 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error return nil } - // 3. Read corpus stats. + // 3. Read corpus stats from direct Badger KV. statsKey := x.BM25StatsKey(attr) - statsPl, err := qs.cache.Get(statsKey) - if err != nil { - // No stats means no documents indexed yet. - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) - return nil - } - statsVal, err := statsPl.Value(q.ReadTs) - if err != nil || statsVal.Value == nil { - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) - return nil - } - statsData, ok := statsVal.Value.([]byte) - if !ok || len(statsData) != 16 { - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) - return nil - } - docCount := binary.BigEndian.Uint64(statsData[0:8]) - totalTerms := binary.BigEndian.Uint64(statsData[8:16]) - if docCount == 0 { - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) - return nil - } - if totalTerms == 0 { + statsBlob := posting.ReadBM25BlobAt(statsKey, q.ReadTs) + docCount, totalTerms := bm25enc.DecodeStats(statsBlob) + if docCount == 0 || totalTerms == 0 { args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) return nil } @@ -1312,7 +1293,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - // 4. For each query token, read the posting list and collect term info. + // 4. For each query token, read the BM25 term blob and collect term info. type termInfo struct { idf float64 uidTFs map[uint64]uint32 @@ -1321,39 +1302,27 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error for _, token := range queryTokens { key := x.BM25IndexKey(attr, token) - pl, err := qs.cache.Get(key) - if err != nil { + blob := posting.ReadBM25BlobAt(key, q.ReadTs) + entries := bm25enc.Decode(blob) + if len(entries) == 0 { continue } ti := &termInfo{uidTFs: make(map[uint64]uint32)} - var df float64 - err = pl.Iterate(q.ReadTs, 0, func(p *pb.Posting) error { - df++ - // When used as filter, only collect TF for UIDs in the filter set. + df := float64(len(entries)) + for _, e := range entries { if filterSet != nil { - if _, ok := filterSet[p.Uid]; !ok { - return nil - } - } - tf := uint32(1) - for _, f := range p.Facets { - if f.Key == "tf" && len(f.Value) >= 4 { - tf = binary.BigEndian.Uint32(f.Value[:4]) - break + if _, ok := filterSet[e.UID]; !ok { + continue } } - ti.uidTFs[p.Uid] = tf - return nil - }) - if err != nil { - continue + ti.uidTFs[e.UID] = e.Value } ti.idf = math.Log1p((N - df + 0.5) / (df + 0.5)) termInfos[token] = ti } - // 5. Read doc lengths for all UIDs seen. + // 5. Read doc lengths for all UIDs seen using binary search on the doclen blob. allUids := make(map[uint64]struct{}) for _, ti := range termInfos { for uid := range ti.uidTFs { @@ -1361,28 +1330,15 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - docLens := make(map[uint64]uint32) dlKey := x.BM25DocLenKey(attr) - dlPl, err := qs.cache.Get(dlKey) - if err == nil { - remaining := len(allUids) - _ = dlPl.Iterate(q.ReadTs, 0, func(p *pb.Posting) error { - if remaining == 0 { - return posting.ErrStopIteration - } - if _, needed := allUids[p.Uid]; needed { - dl := uint32(1) - for _, f := range p.Facets { - if f.Key == "dl" && len(f.Value) >= 4 { - dl = binary.BigEndian.Uint32(f.Value[:4]) - break - } - } - docLens[p.Uid] = dl - remaining-- - } - return nil - }) + dlBlob := posting.ReadBM25BlobAt(dlKey, q.ReadTs) + dlEntries := bm25enc.Decode(dlBlob) + + docLens := make(map[uint64]uint32, len(allUids)) + for uid := range allUids { + if v, ok := bm25enc.Search(dlEntries, uid); ok { + docLens[uid] = v + } } // 6. Compute final BM25 scores. @@ -1408,7 +1364,6 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error for uid, score := range scores { results = append(results, uidScore{uid: uid, score: score}) } - // Sort by score descending for ordering, then collect UIDs. sort.Slice(results, func(i, j int) bool { if results[i].score != results[j].score { return results[i].score > results[j].score @@ -1428,14 +1383,26 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - // Build output UIDs sorted by UID (ascending) as required by the query pipeline. + // Build output: UIDs sorted ascending (required by query pipeline) + // and ValueMatrix with aligned scores (for bm25_score pseudo-predicate). + sort.Slice(results, func(i, j int) bool { return results[i].uid < results[j].uid }) uids := make([]uint64, len(results)) for i, r := range results { uids[i] = r.uid } - sort.Slice(uids, func(i, j int) bool { return uids[i] < uids[j] }) - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) + + // Populate ValueMatrix with BM25 scores aligned to UIDs. + // Each entry is a ValueList with a single float64 value. + scoreValues := make([]*pb.ValueList, len(results)) + for i, r := range results { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, math.Float64bits(r.score)) + scoreValues[i] = &pb.ValueList{ + Values: []*pb.TaskValue{{Val: buf, ValType: pb.Posting_ValType(pb.Posting_FLOAT)}}, + } + } + args.out.ValueMatrix = append(args.out.ValueMatrix, scoreValues...) return nil } From de8c9e6b4b5b7de164e3319ec4cf313969f89b9c Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 4 Mar 2026 23:51:09 -0500 Subject: [PATCH 04/12] test(bm25): add 15 integration tests for mutation scenarios and edge cases Cover incremental add/update/delete, IDF score stability as corpus grows, large corpus pagination, unicode, stopwords, uid filtering, score validation, and concurrent batch adds. Co-Authored-By: Claude Opus 4.6 --- query/query_bm25_test.go | 535 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 535 insertions(+) diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index cdb235be36f..6f469220ab5 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -10,6 +10,10 @@ package query import ( "context" + "encoding/json" + "fmt" + "math" + "strings" "testing" "github.com/stretchr/testify/require" @@ -300,3 +304,534 @@ func TestBM25ScoreWithPagination(t *testing.T) { require.NotContains(t, js, "fox fox fox") require.Contains(t, js, "fox") } + +// parseScoresFromJSON extracts uid → score from JSON responses containing val(score). +func parseScoresFromJSON(t *testing.T, js string) map[string]float64 { + t.Helper() + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + Score float64 `json:"val(score)"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + scores := make(map[string]float64) + for _, item := range resp.Data.Me { + scores[item.UID] = item.Score + } + return scores +} + +func TestBM25IncrementalAddBatch(t *testing.T) { + batch1 := ` + <600> "alpha bravo charlie" . + <601> "delta echo foxtrot" . + ` + batch2 := ` + <602> "golf hotel india" . + <603> "juliet kilo lima" . + <604> "mike november oscar" . + ` + batch3 := ` + <605> "papa quebec romeo" . + <606> "sierra tango uniform" . + <607> "victor whiskey xray" . + ` + cleanup := func() { + deleteTriplesInCluster(` + <600> * . + <601> * . + <602> * . + <603> * . + <604> * . + <605> * . + <606> * . + <607> * . + `) + } + t.Cleanup(cleanup) + + countQuery := ` + { + me(func: bm25(description_bm25, "alpha bravo delta echo golf juliet mike papa sierra victor")) { + count(uid) + } + } + ` + + // Batch 1: add 2 docs. + require.NoError(t, addTriplesToCluster(batch1)) + js := processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":2`) + + // Batch 2: add 3 more docs → total 5. + require.NoError(t, addTriplesToCluster(batch2)) + js = processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":5`) + + // Batch 3: add 3 more docs → total 8. + require.NoError(t, addTriplesToCluster(batch3)) + js = processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":8`) + + // Verify specific new UIDs are searchable. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "whiskey")) { uid } }`) + require.Contains(t, js, `"0x25e"`) // 606 +} + +func TestBM25CorpusStatsAffectIDF(t *testing.T) { + // Capture baseline score for "fox" query. + scoreQuery := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + jsBefore := processQueryNoErr(t, scoreQuery) + scoresBefore := parseScoresFromJSON(t, jsBefore) + require.NotEmpty(t, scoresBefore, "baseline should have fox results") + + // Add 10 non-fox docs → N grows, df("fox") stays same → IDF should increase. + var triples string + for i := 610; i < 620; i++ { + triples += fmt.Sprintf(`<%d> "completely unrelated document about cats and dogs number %d" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + var del string + for i := 610; i < 620; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + jsAfter := processQueryNoErr(t, scoreQuery) + scoresAfter := parseScoresFromJSON(t, jsAfter) + + // Compare score for UID 503 ("fox fox fox") — should increase. + uid503 := "0x1f7" + before, ok1 := scoresBefore[uid503] + after, ok2 := scoresAfter[uid503] + require.True(t, ok1 && ok2, "UID 503 should appear in both before and after results") + require.Greater(t, after, before, + "IDF should increase when corpus grows with non-matching docs (before=%f, after=%f)", before, after) +} + +func TestBM25DocumentUpdate(t *testing.T) { + // Add a doc with lots of "fox". + require.NoError(t, addTriplesToCluster(`<620> "fox fox fox fox" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<620> * .`) + }) + + // Should rank top for "fox". + js := processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + } + }`) + require.Contains(t, js, `"0x26c"`) // 620 + + // Update to remove "fox", add "cat". + deleteTriplesInCluster(`<620> "fox fox fox fox" .`) + require.NoError(t, addTriplesToCluster(`<620> "the cat sat on the mat" .`)) + + // Should no longer appear in "fox" results. + js = processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "fox")) { + uid + } + }`) + require.NotContains(t, js, `"0x26c"`) + + // Should appear in "cat" results. + js = processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "cat")) { + uid + } + }`) + require.Contains(t, js, `"0x26c"`) +} + +func TestBM25DocumentDeletion(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<625> "unique elephant term" .`)) + t.Cleanup(func() { + // Cleanup in case test fails before explicit delete. + deleteTriplesInCluster(`<625> * .`) + }) + + // Should find the elephant doc. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.Contains(t, js, `"0x271"`) // 625 + + // Delete it. + deleteTriplesInCluster(`<625> "unique elephant term" .`) + + // Should return empty for "elephant". + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.JSONEq(t, `{"data": {"me":[]}}`, js) + + // Baseline "fox" results should be unaffected. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "fox")) { uid } }`) + require.Contains(t, js, "fox") +} + +func TestBM25ScoreStabilityAsCorpusGrows(t *testing.T) { + scoreQuery := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + uid503 := "0x1f7" + + // Phase 1: baseline score. + js1 := processQueryNoErr(t, scoreQuery) + scores1 := parseScoresFromJSON(t, js1) + score1, ok := scores1[uid503] + require.True(t, ok, "UID 503 must appear in baseline") + + // Phase 2: add 5 fox docs → IDF decreases. + var foxTriples string + for i := 630; i < 635; i++ { + foxTriples += fmt.Sprintf(`<%d> "the fox runs quickly across the field number %d" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(foxTriples)) + t.Cleanup(func() { + var del string + for i := 630; i < 640; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + js2 := processQueryNoErr(t, scoreQuery) + scores2 := parseScoresFromJSON(t, js2) + score2, ok := scores2[uid503] + require.True(t, ok, "UID 503 must appear after adding fox docs") + require.Greater(t, score1, score2, + "Adding fox docs should decrease IDF and thus score (phase1=%f, phase2=%f)", score1, score2) + + // Phase 3: add 5 non-fox docs → IDF increases relative to phase 2. + var nonFoxTriples string + for i := 635; i < 640; i++ { + nonFoxTriples += fmt.Sprintf(`<%d> "unrelated content about birds and fish number %d" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(nonFoxTriples)) + + js3 := processQueryNoErr(t, scoreQuery) + scores3 := parseScoresFromJSON(t, js3) + score3, ok := scores3[uid503] + require.True(t, ok, "UID 503 must appear after adding non-fox docs") + require.Greater(t, score3, score2, + "Adding non-fox docs should increase IDF relative to phase2 (phase2=%f, phase3=%f)", score2, score3) +} + +func TestBM25LargeCorpus(t *testing.T) { + // Add 100 docs: 50 with "alpha", 50 with "beta". + var triples string + for i := 700; i < 750; i++ { + triples += fmt.Sprintf(`<%d> "alpha document content number %d with some padding words" . +`, i, i) + } + for i := 750; i < 800; i++ { + triples += fmt.Sprintf(`<%d> "beta document content number %d with some padding words" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + var del string + for i := 700; i < 800; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + // Count alpha docs. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "alpha")) { count(uid) } }`) + require.Contains(t, js, `"count":50`) + + // Count beta docs. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "beta")) { count(uid) } }`) + require.Contains(t, js, `"count":50`) + + // Union count: "alpha beta" should match all 100. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "alpha beta")) { count(uid) } }`) + require.Contains(t, js, `"count":100`) + + // Pagination: first:10, offset:40 for alpha should return 10 results. + js = processQueryNoErr(t, ` + { + var(func: bm25(description_bm25, "alpha")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score), first: 10, offset: 40) { + uid + } + }`) + var resp struct { + Data struct { + Me []struct{ UID string } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.Len(t, resp.Data.Me, 10, "pagination first:10 offset:40 should return exactly 10 results") +} + +func TestBM25EdgeCaseSingleCharTerm(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<640> "x y z" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<640> * .`) + }) + + // Single-char terms may or may not be indexed depending on tokenizer. + // Just verify no panic/error. + _, err := processQuery(context.Background(), t, ` + { + me(func: bm25(description_bm25, "x")) { + uid + } + }`) + require.NoError(t, err) +} + +func TestBM25EdgeCaseLongDocument(t *testing.T) { + // Build a ~500-word document with "fox" appearing once. + words := make([]string, 500) + for i := range words { + words[i] = "padding" + } + words[250] = "fox" + longDoc := strings.Join(words, " ") + + require.NoError(t, addTriplesToCluster(fmt.Sprintf(`<645> %q .`, longDoc))) + t.Cleanup(func() { + deleteTriplesInCluster(`<645> * .`) + }) + + // Get scores for "fox" query. + scoreQuery := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + uid503 := "0x1f7" // "fox fox fox" (doclen=3) + uid645 := "0x285" // long doc (doclen~500) + s503, ok1 := scores[uid503] + s645, ok2 := scores[uid645] + require.True(t, ok1, "UID 503 must appear in fox results") + require.True(t, ok2, "UID 645 must appear in fox results") + require.Greater(t, s503, s645, + "Short doc with high tf should score higher than long doc with low tf (503=%f, 645=%f)", s503, s645) +} + +func TestBM25EdgeCaseUnicode(t *testing.T) { + triples := ` + <650> "der schnelle braune Fuchs springt" . + <651> "le renard brun rapide saute" . + <652> "el zorro marrón rápido salta" . + ` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(` + <650> * . + <651> * . + <652> * . + `) + }) + + // Query German term. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "Fuchs")) { uid } }`) + require.Contains(t, js, `"0x28a"`) // 650 + + // Query French term. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "renard")) { uid } }`) + require.Contains(t, js, `"0x28b"`) // 651 + + // Query Spanish term. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "zorro")) { uid } }`) + require.Contains(t, js, `"0x28c"`) // 652 +} + +func TestBM25EdgeCaseAllStopwordsDoc(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<655> "the a an is are was were" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<655> * .`) + }) + + // Query "the" — should return empty since "the" is a stopword. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "the")) { uid } }`) + require.NotContains(t, js, `"0x28f"`) // 655 should not appear + + // But the doc should exist via has(). + js = processQueryNoErr(t, ` + { + me(func: has(description_bm25)) @filter(uid(655)) { + uid + } + }`) + require.Contains(t, js, `"0x28f"`) +} + +func TestBM25WithUidFilter(t *testing.T) { + // BM25 root with uid filter to restrict results. + query := ` + { + me(func: bm25(description_bm25, "fox")) @filter(uid(501, 503)) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should contain only UIDs 501 and 503. + require.Contains(t, js, `"0x1f5"`) // 501 + require.Contains(t, js, `"0x1f7"`) // 503 + // Should NOT contain other fox docs like 502, 506, 507. + require.NotContains(t, js, `"0x1f6"`) // 502 + require.NotContains(t, js, `"0x1fa"`) // 506 +} + +func TestBM25ScoreValuesAreValidFloats(t *testing.T) { + scoreQuery := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + require.NotEmpty(t, scores, "should have at least one result") + + var prevScore float64 + first := true + // Iterate over results in order (they're orderdesc by score). + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + Score float64 `json:"val(score)"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + + for _, item := range resp.Data.Me { + score := item.Score + require.False(t, math.IsNaN(score), "score should not be NaN for uid %s", item.UID) + require.False(t, math.IsInf(score, 0), "score should not be Inf for uid %s", item.UID) + require.Greater(t, score, 0.0, "score should be positive for uid %s", item.UID) + + if !first { + require.GreaterOrEqual(t, prevScore, score, + "scores should be in descending order: %f >= %f", prevScore, score) + } + prevScore = score + first = false + } +} + +func TestBM25IncrementalAddThenDeleteThenReadd(t *testing.T) { + t.Cleanup(func() { + deleteTriplesInCluster(`<670> * .`) + }) + + // Phase 1: add with "elephant". + require.NoError(t, addTriplesToCluster(`<670> "elephant roams the savanna" .`)) + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.Contains(t, js, `"0x29e"`) // 670 + + // Phase 2: delete. + deleteTriplesInCluster(`<670> "elephant roams the savanna" .`) + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.NotContains(t, js, `"0x29e"`) + + // Phase 3: re-add with different content. + require.NoError(t, addTriplesToCluster(`<670> "penguin waddles on the ice" .`)) + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "penguin")) { uid } }`) + require.Contains(t, js, `"0x29e"`) + + // "elephant" should still not match 670. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.NotContains(t, js, `"0x29e"`) +} + +func TestBM25NonIndexedPredicateError(t *testing.T) { + // "name" predicate does not have @index(bm25). + query := ` + { + me(func: bm25(name, "alice")) { + uid + } + } + ` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25") +} + +func TestBM25ConcurrentBatchAdd(t *testing.T) { + // Add 5 batches of 4 docs each (UIDs 680-699) back-to-back. + t.Cleanup(func() { + var del string + for i := 680; i < 700; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + for batch := 0; batch < 5; batch++ { + var triples string + for j := 0; j < 4; j++ { + uid := 680 + batch*4 + j + triples += fmt.Sprintf(`<%d> "searchterm batch%d doc%d content here" . +`, uid, batch, j) + } + require.NoError(t, addTriplesToCluster(triples)) + } + + // All 20 docs should be findable. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "searchterm")) { count(uid) } }`) + require.Contains(t, js, `"count":20`) + + // Spot-check a doc from each batch. + for batch := 0; batch < 5; batch++ { + uid := 680 + batch*4 + hexUID := fmt.Sprintf(`"0x%x"`, uid) + term := fmt.Sprintf("batch%d", batch) + js = processQueryNoErr(t, fmt.Sprintf(`{ me(func: bm25(description_bm25, "%s")) { uid } }`, term)) + require.Contains(t, js, hexUID, "doc %d from batch %d should be searchable", uid, batch) + } +} From 41e43156eaa86f8f420a64ed5e85f5947953dcbe Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 07:46:43 -0500 Subject: [PATCH 05/12] test(bm25): add exact score verification, BM15 variant, and single-doc tests Addresses test coverage gaps identified during code review against ArangoDB's BM25 implementation: - TestBM25ExactScoreValues: validates numerical correctness of BM25 formula using b=0 to enable hand-computed expected scores - TestBM25BM15NoLengthNormalization: verifies b=0 disables length normalization and contrasts with default b=0.75 behavior - TestBM25SingleMatchingDocument: covers df=1 edge case with high IDF Co-Authored-By: Claude Opus 4.6 --- query/query_bm25_test.go | 181 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index 6f469220ab5..457c7b46452 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -835,3 +835,184 @@ func TestBM25ConcurrentBatchAdd(t *testing.T) { require.Contains(t, js, hexUID, "doc %d from batch %d should be searchable", uid, batch) } } + +// parseCorpusCount returns the total number of documents with the description_bm25 predicate. +func parseCorpusCount(t *testing.T) float64 { + t.Helper() + js := processQueryNoErr(t, `{ me(func: has(description_bm25)) { count(uid) } }`) + var resp struct { + Data struct { + Me []struct { + Count int `json:"count"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.NotEmpty(t, resp.Data.Me) + n := float64(resp.Data.Me[0].Count) + require.Greater(t, n, 0.0, "corpus must have documents") + return n +} + +func TestBM25ExactScoreValues(t *testing.T) { + // Exact score verification using b=0 (BM15 variant) to eliminate avgDL dependency. + // With b=0: score = idf * (k+1) * tf / (k + tf) + // This validates the core BM25 formula computes correct numerical values. + triples := ` + <850> "quasar quasar quasar" . + <851> "quasar nebula pulsar" . + ` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(` + <850> * . + <851> * . + `) + }) + + N := parseCorpusCount(t) + + // Query "quasar" with b=0 so score depends only on tf, k, and IDF (not avgDL). + scoreQuery := ` + { + var(func: bm25(description_bm25, "quasar", "1.2", "0")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + k := 1.2 + df := 2.0 // both 850 and 851 contain "quasar" + idf := math.Log1p((N - df + 0.5) / (df + 0.5)) + + // Doc 850 "quasar quasar quasar": tf=3, b=0 → score = idf * 2.2 * 3 / 4.2 + expected850 := idf * (k + 1) * 3.0 / (k + 3.0) + // Doc 851 "quasar nebula pulsar": tf=1, b=0 → score = idf * 2.2 * 1 / 2.2 = idf + expected851 := idf * (k + 1) * 1.0 / (k + 1.0) + + actual850, ok := scores["0x352"] // 850 + require.True(t, ok, "UID 850 (0x352) must be in results") + actual851, ok := scores["0x353"] // 851 + require.True(t, ok, "UID 851 (0x353) must be in results") + + require.InEpsilon(t, expected850, actual850, 1e-6, + "Doc 850 score mismatch: expected %f, got %f (N=%f, df=%f, idf=%f)", + expected850, actual850, N, df, idf) + require.InEpsilon(t, expected851, actual851, 1e-6, + "Doc 851 score mismatch: expected %f, got %f (N=%f, df=%f, idf=%f)", + expected851, actual851, N, df, idf) + + // Verify ordering: higher tf should yield higher score. + require.Greater(t, actual850, actual851) +} + +func TestBM25BM15NoLengthNormalization(t *testing.T) { + // With b=0 (BM15 variant), document length should NOT affect the score. + // Two docs with the same term frequency but different lengths must score identically. + triples := ` + <860> "vortex" . + <861> "vortex alpha bravo charlie delta echo foxtrot golf hotel india" . + ` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(` + <860> * . + <861> * . + `) + }) + + // Query with b=0: length normalization disabled. + scoreQuery := ` + { + var(func: bm25(description_bm25, "vortex", "1.2", "0")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + score860, ok1 := scores["0x35c"] // 860 + score861, ok2 := scores["0x35d"] // 861 + require.True(t, ok1, "UID 860 must be in results") + require.True(t, ok2, "UID 861 must be in results") + + // With b=0 and same tf=1, scores must be equal regardless of document length. + require.InDelta(t, score860, score861, 1e-9, + "b=0 should disable length normalization: short doc score=%f, long doc score=%f", + score860, score861) + + // Now verify that with default b=0.75, the shorter doc scores higher. + scoreQueryDefault := ` + { + var(func: bm25(description_bm25, "vortex")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js = processQueryNoErr(t, scoreQueryDefault) + scoresDefault := parseScoresFromJSON(t, js) + + defScore860, ok1 := scoresDefault["0x35c"] + defScore861, ok2 := scoresDefault["0x35d"] + require.True(t, ok1, "UID 860 must be in default results") + require.True(t, ok2, "UID 861 must be in default results") + require.Greater(t, defScore860, defScore861, + "With b=0.75, shorter doc (doclen=1) should score higher than longer doc (doclen=10)") +} + +func TestBM25SingleMatchingDocument(t *testing.T) { + // Edge case: a single document matching the query term (df=1). + // IDF should be high since the term is very rare. + triples := `<865> "aardvark" .` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(`<865> * .`) + }) + + N := parseCorpusCount(t) + + // Query with b=0 for exact verification. + scoreQuery := ` + { + var(func: bm25(description_bm25, "aardvark", "1.2", "0")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + require.Len(t, scores, 1, "exactly one document should match 'aardvark'") + + actual, ok := scores["0x361"] // 865 + require.True(t, ok, "UID 865 (0x361) must be in results") + + // With df=1, tf=1, b=0, k=1.2: + // idf = log1p((N - 1 + 0.5) / (1 + 0.5)) = log1p((N - 0.5) / 1.5) + // score = idf * 2.2 * 1 / (1.2 + 1) = idf * 2.2 / 2.2 = idf + k := 1.2 + df := 1.0 + idf := math.Log1p((N - df + 0.5) / (df + 0.5)) + expected := idf * (k + 1) * 1.0 / (k + 1.0) // simplifies to idf + + require.InEpsilon(t, expected, actual, 1e-6, + "Single-doc score mismatch: expected %f, got %f (N=%f, idf=%f)", + expected, actual, N, idf) + require.Greater(t, actual, 0.0, "score must be positive") + require.False(t, math.IsInf(actual, 0), "score must be finite") +} From 0a4adba170a7a503a47f4c3a5e6a30398aae2e0b Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 08:07:35 -0500 Subject: [PATCH 06/12] feat(bm25): add block storage infrastructure for segmented column stores Phase 1 of BM25 scaling plan. Introduces bm25block package with: - BlockMeta/Dir types for block directory encoding/decoding - SplitIntoBlocks: splits monolithic entry slices into 128-entry blocks - MergeAllBlocks: compacts overlapping blocks with dedup and tombstone removal - ComputeUBPre/SuffixMaxUBPre: WAND upper-bound precomputation - New key functions: BM25TermDirKey, BM25TermBlockKey, BM25DocLenDirKey, BM25DocLenBlockKey for block-addressed Badger KV storage 17 unit tests and benchmarks for the block storage format. Co-Authored-By: Claude Opus 4.6 --- posting/bm25block/bm25block.go | 261 ++++++++++++++++++++++++++++ posting/bm25block/bm25block_test.go | 258 +++++++++++++++++++++++++++ x/keys.go | 24 +++ 3 files changed, 543 insertions(+) create mode 100644 posting/bm25block/bm25block.go create mode 100644 posting/bm25block/bm25block_test.go diff --git a/posting/bm25block/bm25block.go b/posting/bm25block/bm25block.go new file mode 100644 index 00000000000..f529ed8fab8 --- /dev/null +++ b/posting/bm25block/bm25block.go @@ -0,0 +1,261 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Package bm25block provides block-based storage for BM25 index data. +// +// Instead of storing all postings for a term in a single blob, this package +// splits them into fixed-size blocks (~128 entries). Each block is stored as +// a separate Badger KV entry, and a lightweight directory indexes the blocks. +// +// This enables: +// - Selective I/O: queries only read blocks they need +// - WAND/Block-Max WAND: per-block upper bounds enable early termination +// - Efficient mutations: only the affected block is rewritten +package bm25block + +import ( + "encoding/binary" + "math" + "sort" + + "github.com/dgraph-io/dgraph/v25/posting/bm25enc" +) + +const ( + // TargetBlockSize is the ideal number of entries per block. + TargetBlockSize = 128 + // MaxBlockSize is the threshold at which a block is split. + MaxBlockSize = 256 + // DocLenBlockSize is the target entries per document-length block. + DocLenBlockSize = 512 + + // dirHeaderSize is 4 (blockCount) + 4 (nextID). + dirHeaderSize = 8 + // dirEntrySize is 8 (firstUID) + 4 (blockID) + 4 (count) + 4 (maxTF). + dirEntrySize = 20 +) + +// BlockMeta stores metadata for a single block in a directory. +type BlockMeta struct { + FirstUID uint64 + BlockID uint32 + Count uint32 + MaxTF uint32 +} + +// Dir is a block directory for a term's posting list or document-length list. +type Dir struct { + Blocks []BlockMeta + NextID uint32 // next available block ID +} + +// EncodeDir encodes a directory to bytes. Returns nil for an empty directory. +func EncodeDir(d *Dir) []byte { + if d == nil || len(d.Blocks) == 0 { + return nil + } + buf := make([]byte, dirHeaderSize+len(d.Blocks)*dirEntrySize) + binary.BigEndian.PutUint32(buf[0:4], uint32(len(d.Blocks))) + binary.BigEndian.PutUint32(buf[4:8], d.NextID) + off := dirHeaderSize + for _, b := range d.Blocks { + binary.BigEndian.PutUint64(buf[off:off+8], b.FirstUID) + binary.BigEndian.PutUint32(buf[off+8:off+12], b.BlockID) + binary.BigEndian.PutUint32(buf[off+12:off+16], b.Count) + binary.BigEndian.PutUint32(buf[off+16:off+20], b.MaxTF) + off += dirEntrySize + } + return buf +} + +// DecodeDir decodes a directory from bytes. Returns an empty Dir for nil/invalid input. +func DecodeDir(data []byte) *Dir { + if len(data) < dirHeaderSize { + return &Dir{} + } + count := binary.BigEndian.Uint32(data[0:4]) + nextID := binary.BigEndian.Uint32(data[4:8]) + if int(count)*dirEntrySize+dirHeaderSize > len(data) { + return &Dir{NextID: nextID} + } + blocks := make([]BlockMeta, count) + off := dirHeaderSize + for i := uint32(0); i < count; i++ { + blocks[i] = BlockMeta{ + FirstUID: binary.BigEndian.Uint64(data[off : off+8]), + BlockID: binary.BigEndian.Uint32(data[off+8 : off+12]), + Count: binary.BigEndian.Uint32(data[off+12 : off+16]), + MaxTF: binary.BigEndian.Uint32(data[off+16 : off+20]), + } + off += dirEntrySize + } + return &Dir{Blocks: blocks, NextID: nextID} +} + +// FindBlock returns the index of the block that should contain uid. +// Returns 0 if the directory is empty (caller should create first block). +func (d *Dir) FindBlock(uid uint64) int { + if len(d.Blocks) == 0 { + return 0 + } + // Binary search: find the last block where FirstUID <= uid. + i := sort.Search(len(d.Blocks), func(i int) bool { + return d.Blocks[i].FirstUID > uid + }) + if i > 0 { + return i - 1 + } + return 0 +} + +// AllocBlockID returns the next available block ID and increments the counter. +func (d *Dir) AllocBlockID() uint32 { + id := d.NextID + d.NextID++ + return id +} + +// UpdateBlockMeta recomputes metadata for the block at index idx from entries. +func (d *Dir) UpdateBlockMeta(idx int, entries []bm25enc.Entry) { + if idx < 0 || idx >= len(d.Blocks) || len(entries) == 0 { + return + } + d.Blocks[idx].FirstUID = entries[0].UID + d.Blocks[idx].Count = uint32(len(entries)) + var maxTF uint32 + for _, e := range entries { + if e.Value > maxTF { + maxTF = e.Value + } + } + d.Blocks[idx].MaxTF = maxTF +} + +// InsertBlockMeta inserts a new block at position idx. +func (d *Dir) InsertBlockMeta(idx int, meta BlockMeta) { + d.Blocks = append(d.Blocks, BlockMeta{}) + copy(d.Blocks[idx+1:], d.Blocks[idx:]) + d.Blocks[idx] = meta +} + +// RemoveBlockMeta removes the block at position idx. +func (d *Dir) RemoveBlockMeta(idx int) { + if idx < 0 || idx >= len(d.Blocks) { + return + } + d.Blocks = append(d.Blocks[:idx], d.Blocks[idx+1:]...) +} + +// SplitIntoBlocks splits a sorted entry slice into blocks of TargetBlockSize. +// Returns a new Dir and a map of blockID -> entries. +func SplitIntoBlocks(entries []bm25enc.Entry) (*Dir, map[uint32][]bm25enc.Entry) { + if len(entries) == 0 { + return &Dir{}, nil + } + dir := &Dir{} + blockMap := make(map[uint32][]bm25enc.Entry) + + for i := 0; i < len(entries); i += TargetBlockSize { + end := i + TargetBlockSize + if end > len(entries) { + end = len(entries) + } + block := entries[i:end] + blockID := dir.AllocBlockID() + + var maxTF uint32 + for _, e := range block { + if e.Value > maxTF { + maxTF = e.Value + } + } + + dir.Blocks = append(dir.Blocks, BlockMeta{ + FirstUID: block[0].UID, + BlockID: blockID, + Count: uint32(len(block)), + MaxTF: maxTF, + }) + // Make a copy so the caller owns the slice. + cp := make([]bm25enc.Entry, len(block)) + copy(cp, block) + blockMap[blockID] = cp + } + return dir, blockMap +} + +// MergeAllBlocks reads all block entries from a map (keyed by blockID), +// merges them into a single sorted slice, then re-splits into clean blocks. +func MergeAllBlocks(dir *Dir, readBlock func(blockID uint32) []bm25enc.Entry) (*Dir, map[uint32][]bm25enc.Entry) { + var all []bm25enc.Entry + for _, bm := range dir.Blocks { + entries := readBlock(bm.BlockID) + all = append(all, entries...) + } + // Sort by UID and deduplicate (keep last occurrence for same UID). + sort.Slice(all, func(i, j int) bool { return all[i].UID < all[j].UID }) + deduped := make([]bm25enc.Entry, 0, len(all)) + for i, e := range all { + if i > 0 && e.UID == all[i-1].UID { + deduped[len(deduped)-1] = e // overwrite with latest + continue + } + deduped = append(deduped, e) + } + // Remove tombstones (Value == 0). + live := deduped[:0] + for _, e := range deduped { + if e.Value > 0 { + live = append(live, e) + } + } + return SplitIntoBlocks(live) +} + +// ComputeUBPre computes the upper-bound pre-IDF BM25 contribution for a block +// given its maxTF and query parameters k and b. +// With dl=0 (best case for scoring): score = (maxTF*(k+1)) / (maxTF + k*(1-b)) +func ComputeUBPre(maxTF uint32, k, b float64) float64 { + if maxTF == 0 { + return 0 + } + tf := float64(maxTF) + return tf * (k + 1) / (tf + k*(1-b)) +} + +// SuffixMaxUBPre computes suffix maxima of UBPre values for WAND. +// suffixMax[i] = max(ubPre[i], ubPre[i+1], ..., ubPre[n-1]) +func SuffixMaxUBPre(dir *Dir, k, b float64) []float64 { + n := len(dir.Blocks) + if n == 0 { + return nil + } + suf := make([]float64, n) + suf[n-1] = ComputeUBPre(dir.Blocks[n-1].MaxTF, k, b) + for i := n - 2; i >= 0; i-- { + ub := ComputeUBPre(dir.Blocks[i].MaxTF, k, b) + suf[i] = math.Max(ub, suf[i+1]) + } + return suf +} + +// BlockMetaFromEntries computes a BlockMeta from entries. +func BlockMetaFromEntries(blockID uint32, entries []bm25enc.Entry) BlockMeta { + if len(entries) == 0 { + return BlockMeta{BlockID: blockID} + } + var maxTF uint32 + for _, e := range entries { + if e.Value > maxTF { + maxTF = e.Value + } + } + return BlockMeta{ + FirstUID: entries[0].UID, + BlockID: blockID, + Count: uint32(len(entries)), + MaxTF: maxTF, + } +} diff --git a/posting/bm25block/bm25block_test.go b/posting/bm25block/bm25block_test.go new file mode 100644 index 00000000000..a7cc26f493a --- /dev/null +++ b/posting/bm25block/bm25block_test.go @@ -0,0 +1,258 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package bm25block + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dgraph-io/dgraph/v25/posting/bm25enc" +) + +func TestDirRoundtrip(t *testing.T) { + dir := &Dir{ + NextID: 5, + Blocks: []BlockMeta{ + {FirstUID: 100, BlockID: 0, Count: 128, MaxTF: 10}, + {FirstUID: 500, BlockID: 1, Count: 128, MaxTF: 5}, + {FirstUID: 900, BlockID: 2, Count: 64, MaxTF: 20}, + }, + } + data := EncodeDir(dir) + got := DecodeDir(data) + require.Equal(t, dir.NextID, got.NextID) + require.Equal(t, dir.Blocks, got.Blocks) +} + +func TestDirRoundtripEmpty(t *testing.T) { + require.Nil(t, EncodeDir(nil)) + require.Nil(t, EncodeDir(&Dir{})) + + got := DecodeDir(nil) + require.Empty(t, got.Blocks) + got = DecodeDir([]byte{}) + require.Empty(t, got.Blocks) +} + +func TestDirRoundtripSingle(t *testing.T) { + dir := &Dir{ + NextID: 1, + Blocks: []BlockMeta{{FirstUID: 42, BlockID: 0, Count: 1, MaxTF: 3}}, + } + got := DecodeDir(EncodeDir(dir)) + require.Equal(t, dir.Blocks, got.Blocks) +} + +func TestFindBlock(t *testing.T) { + dir := &Dir{ + Blocks: []BlockMeta{ + {FirstUID: 100}, + {FirstUID: 500}, + {FirstUID: 900}, + }, + } + require.Equal(t, 0, dir.FindBlock(50)) // before first block + require.Equal(t, 0, dir.FindBlock(100)) // exact first + require.Equal(t, 0, dir.FindBlock(200)) // within first block + require.Equal(t, 1, dir.FindBlock(500)) // exact second + require.Equal(t, 1, dir.FindBlock(700)) // within second block + require.Equal(t, 2, dir.FindBlock(900)) // exact third + require.Equal(t, 2, dir.FindBlock(9999)) // beyond last block +} + +func TestFindBlockEmpty(t *testing.T) { + dir := &Dir{} + require.Equal(t, 0, dir.FindBlock(100)) +} + +func TestAllocBlockID(t *testing.T) { + dir := &Dir{NextID: 3} + require.Equal(t, uint32(3), dir.AllocBlockID()) + require.Equal(t, uint32(4), dir.AllocBlockID()) + require.Equal(t, uint32(5), dir.NextID) +} + +func TestSplitIntoBlocks(t *testing.T) { + // Create 300 entries. + entries := make([]bm25enc.Entry, 300) + for i := range entries { + entries[i] = bm25enc.Entry{UID: uint64(i + 1), Value: uint32(i%10 + 1)} + } + dir, blockMap := SplitIntoBlocks(entries) + + // Should split into ceil(300/128) = 3 blocks. + require.Len(t, dir.Blocks, 3) + require.Len(t, blockMap, 3) + + // First block: 128 entries. + require.Equal(t, uint32(128), dir.Blocks[0].Count) + require.Equal(t, uint64(1), dir.Blocks[0].FirstUID) + require.Len(t, blockMap[dir.Blocks[0].BlockID], 128) + + // Second block: 128 entries. + require.Equal(t, uint32(128), dir.Blocks[1].Count) + require.Equal(t, uint64(129), dir.Blocks[1].FirstUID) + + // Third block: 44 entries. + require.Equal(t, uint32(44), dir.Blocks[2].Count) + require.Equal(t, uint64(257), dir.Blocks[2].FirstUID) + + // NextID should be 3. + require.Equal(t, uint32(3), dir.NextID) +} + +func TestSplitIntoBlocksEmpty(t *testing.T) { + dir, blockMap := SplitIntoBlocks(nil) + require.Empty(t, dir.Blocks) + require.Nil(t, blockMap) +} + +func TestSplitIntoBlocksSmall(t *testing.T) { + entries := []bm25enc.Entry{{UID: 1, Value: 5}, {UID: 2, Value: 3}} + dir, blockMap := SplitIntoBlocks(entries) + require.Len(t, dir.Blocks, 1) + require.Equal(t, uint32(2), dir.Blocks[0].Count) + require.Equal(t, uint32(5), dir.Blocks[0].MaxTF) + require.Equal(t, entries, blockMap[0]) +} + +func TestUpdateBlockMeta(t *testing.T) { + dir := &Dir{ + Blocks: []BlockMeta{{FirstUID: 100, BlockID: 0, Count: 3, MaxTF: 5}}, + } + entries := []bm25enc.Entry{ + {UID: 50, Value: 2}, + {UID: 100, Value: 8}, + {UID: 200, Value: 3}, + {UID: 300, Value: 1}, + } + dir.UpdateBlockMeta(0, entries) + require.Equal(t, uint64(50), dir.Blocks[0].FirstUID) + require.Equal(t, uint32(4), dir.Blocks[0].Count) + require.Equal(t, uint32(8), dir.Blocks[0].MaxTF) +} + +func TestInsertRemoveBlockMeta(t *testing.T) { + dir := &Dir{ + Blocks: []BlockMeta{ + {FirstUID: 100, BlockID: 0}, + {FirstUID: 500, BlockID: 1}, + }, + } + dir.InsertBlockMeta(1, BlockMeta{FirstUID: 300, BlockID: 2}) + require.Len(t, dir.Blocks, 3) + require.Equal(t, uint64(300), dir.Blocks[1].FirstUID) + require.Equal(t, uint64(500), dir.Blocks[2].FirstUID) + + dir.RemoveBlockMeta(1) + require.Len(t, dir.Blocks, 2) + require.Equal(t, uint64(500), dir.Blocks[1].FirstUID) +} + +func TestComputeUBPre(t *testing.T) { + k, b := 1.2, 0.75 + + // maxTF=0 -> 0 + require.Equal(t, 0.0, ComputeUBPre(0, k, b)) + + // maxTF=1: 1 * 2.2 / (1 + 1.2*0.25) = 2.2 / 1.3 + expected := 2.2 / 1.3 + require.InEpsilon(t, expected, ComputeUBPre(1, k, b), 1e-9) + + // maxTF=10: 10 * 2.2 / (10 + 1.2*0.25) = 22 / 10.3 + expected = 22.0 / 10.3 + require.InEpsilon(t, expected, ComputeUBPre(10, k, b), 1e-9) + + // With b=0: score = tf*(k+1)/(tf+k) — no length normalization. + expected = 5.0 * 2.2 / (5.0 + 1.2) + require.InEpsilon(t, expected, ComputeUBPre(5, k, 0), 1e-9) +} + +func TestSuffixMaxUBPre(t *testing.T) { + dir := &Dir{ + Blocks: []BlockMeta{ + {MaxTF: 1}, + {MaxTF: 10}, + {MaxTF: 3}, + }, + } + k, b := 1.2, 0.75 + suf := SuffixMaxUBPre(dir, k, b) + require.Len(t, suf, 3) + + ub0 := ComputeUBPre(1, k, b) + ub1 := ComputeUBPre(10, k, b) + ub2 := ComputeUBPre(3, k, b) + + require.InEpsilon(t, math.Max(ub0, math.Max(ub1, ub2)), suf[0], 1e-9) + require.InEpsilon(t, math.Max(ub1, ub2), suf[1], 1e-9) + require.InEpsilon(t, ub2, suf[2], 1e-9) +} + +func TestSuffixMaxUBPreEmpty(t *testing.T) { + require.Nil(t, SuffixMaxUBPre(&Dir{}, 1.2, 0.75)) +} + +func TestMergeAllBlocks(t *testing.T) { + // Simulate overlapping blocks with a tombstone. + blocks := map[uint32][]bm25enc.Entry{ + 0: {{UID: 1, Value: 3}, {UID: 5, Value: 1}}, + 1: {{UID: 5, Value: 7}, {UID: 10, Value: 2}}, // UID 5 overrides + 2: {{UID: 15, Value: 0}, {UID: 20, Value: 4}}, // UID 15 is tombstone + } + dir := &Dir{ + Blocks: []BlockMeta{ + {FirstUID: 1, BlockID: 0, Count: 2}, + {FirstUID: 5, BlockID: 1, Count: 2}, + {FirstUID: 15, BlockID: 2, Count: 2}, + }, + NextID: 3, + } + newDir, newBlocks := MergeAllBlocks(dir, func(id uint32) []bm25enc.Entry { + return blocks[id] + }) + // After merge: UID 1(3), 5(7), 10(2), 20(4) — UID 15 removed (tombstone). + require.Len(t, newDir.Blocks, 1) // 4 entries fits in one block + require.Len(t, newBlocks, 1) + entries := newBlocks[newDir.Blocks[0].BlockID] + require.Len(t, entries, 4) + require.Equal(t, uint64(1), entries[0].UID) + require.Equal(t, uint32(3), entries[0].Value) + require.Equal(t, uint64(5), entries[1].UID) + require.Equal(t, uint32(7), entries[1].Value) + require.Equal(t, uint64(20), entries[3].UID) +} + +func TestBlockMetaFromEntries(t *testing.T) { + entries := []bm25enc.Entry{ + {UID: 10, Value: 2}, + {UID: 20, Value: 8}, + {UID: 30, Value: 1}, + } + meta := BlockMetaFromEntries(5, entries) + require.Equal(t, uint32(5), meta.BlockID) + require.Equal(t, uint64(10), meta.FirstUID) + require.Equal(t, uint32(3), meta.Count) + require.Equal(t, uint32(8), meta.MaxTF) +} + +func TestBlockMetaFromEntriesEmpty(t *testing.T) { + meta := BlockMetaFromEntries(0, nil) + require.Equal(t, uint32(0), meta.Count) +} + +func BenchmarkSplitIntoBlocks(b *testing.B) { + entries := make([]bm25enc.Entry, 100000) + for i := range entries { + entries[i] = bm25enc.Entry{UID: uint64(i*3 + 1), Value: uint32(i%100 + 1)} + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + SplitIntoBlocks(entries) + } +} diff --git a/x/keys.go b/x/keys.go index 23196fd89c9..0a23ba19c6a 100644 --- a/x/keys.go +++ b/x/keys.go @@ -310,6 +310,30 @@ func BM25StatsKey(attr string) []byte { return IndexKey(attr, BM25Prefix+"__stats__") } +// BM25TermDirKey generates the key for a BM25 term's block directory. +func BM25TermDirKey(attr, term string) []byte { + return IndexKey(attr, BM25Prefix+"__dir__"+term) +} + +// BM25TermBlockKey generates the key for an individual BM25 term posting block. +func BM25TermBlockKey(attr, term string, blockID uint32) []byte { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], blockID) + return IndexKey(attr, BM25Prefix+"__blk__"+term+string(buf[:])) +} + +// BM25DocLenDirKey generates the key for the BM25 document-length block directory. +func BM25DocLenDirKey(attr string) []byte { + return IndexKey(attr, BM25Prefix+"__dldir__") +} + +// BM25DocLenBlockKey generates the key for an individual BM25 document-length block. +func BM25DocLenBlockKey(attr string, segID uint32) []byte { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], segID) + return IndexKey(attr, BM25Prefix+"__dlblk__"+string(buf[:])) +} + // ParsedKey represents a key that has been parsed into its multiple attributes. type ParsedKey struct { Attr string From f08f372b35bb30426e5b7e23ecefd3211c1a8468 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 08:11:04 -0500 Subject: [PATCH 07/12] feat(bm25): segmented block writes and WAND/Block-Max WAND query path Phases 2-4 of BM25 scaling plan: Phase 2 - Segmented mutation path: - addBM25IndexMutations now writes to block-based storage - Each term's postings split into ~128-entry blocks with a directory - Blocks automatically split when exceeding 256 entries - Doc-length list also uses block-based storage - Block removal and directory cleanup on deletes Phase 3 - WAND top-k query path: - New bm25wand.go with listIter for block-based posting list iteration - WAND algorithm with min-heap for top-k early termination - Per-block upper bounds (UBPre) computed from maxTF at query time - Suffix-max UBPre for efficient threshold checking - Falls back to scoring all docs when no first: limit or offset is used Phase 4 - Block-Max WAND: - skipToWithBMW skips entire blocks whose UB + other terms can't beat theta - Avoids Badger reads for blocks that can't contribute to top-k - Enabled by default in handleBM25Search Co-Authored-By: Claude Opus 4.6 --- posting/index.go | 183 ++++++++++++++--- worker/bm25wand.go | 501 +++++++++++++++++++++++++++++++++++++++++++++ worker/task.go | 95 ++------- 3 files changed, 668 insertions(+), 111 deletions(-) create mode 100644 worker/bm25wand.go diff --git a/posting/index.go b/posting/index.go index 826355a3633..d2eadd904e0 100644 --- a/posting/index.go +++ b/posting/index.go @@ -28,6 +28,7 @@ import ( "github.com/dgraph-io/badger/v4" "github.com/dgraph-io/badger/v4/options" bpb "github.com/dgraph-io/badger/v4/pb" + "github.com/dgraph-io/dgraph/v25/posting/bm25block" "github.com/dgraph-io/dgraph/v25/posting/bm25enc" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" @@ -232,8 +233,9 @@ func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, tok } // addBM25IndexMutations handles index mutations for the BM25 tokenizer. -// It stores term frequencies, document lengths, and corpus statistics as direct -// Badger KV entries using compact varint encoding, bypassing posting lists. +// It stores term frequencies, document lengths, and corpus statistics using +// block-based storage: each term's postings and the doclen list are split into +// fixed-size blocks (~128 entries) with a lightweight directory for navigation. func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationInfo) error { attr := info.edge.Attr uid := info.edge.Entity @@ -261,45 +263,168 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn } if info.op == pb.DirectedEdge_DEL { - // For DELETE: remove uid from all BM25 term posting lists and doc length list. + // For DELETE: remove uid from all term blocks and doclen blocks. for term := range termFreqs { encodedTerm := string([]byte{tok.IdentBM25}) + term - key := x.BM25IndexKey(attr, encodedTerm) - blob := txn.cache.ReadBM25Blob(key) - entries := bm25enc.Decode(blob) - entries = bm25enc.Remove(entries, uid) - txn.cache.WriteBM25Blob(key, bm25enc.Encode(entries)) - } - // Remove doc length entry. - dlKey := x.BM25DocLenKey(attr) - blob := txn.cache.ReadBM25Blob(dlKey) - entries := bm25enc.Decode(blob) - entries = bm25enc.Remove(entries, uid) - txn.cache.WriteBM25Blob(dlKey, bm25enc.Encode(entries)) - - // Update corpus stats: decrement doc count and total terms. + txn.bm25BlockRemove(attr, encodedTerm, uid) + } + txn.bm25DocLenBlockRemove(attr, uid) return txn.updateBM25Stats(attr, -1, -int64(docLen)) } - // For SET: store term frequencies and doc length. + // For SET: upsert term frequencies and doc length into blocks. for term, tf := range termFreqs { encodedTerm := string([]byte{tok.IdentBM25}) + term - key := x.BM25IndexKey(attr, encodedTerm) - blob := txn.cache.ReadBM25Blob(key) - entries := bm25enc.Decode(blob) - entries = bm25enc.Upsert(entries, uid, tf) - txn.cache.WriteBM25Blob(key, bm25enc.Encode(entries)) + txn.bm25BlockUpsert(attr, encodedTerm, uid, tf) + } + txn.bm25DocLenBlockUpsert(attr, uid, docLen) + return txn.updateBM25Stats(attr, 1, int64(docLen)) +} + +// bm25BlockUpsert inserts or updates a (uid, value) entry in the block-based +// posting list for the given term. Handles block creation and splitting. +func (txn *Txn) bm25BlockUpsert(attr, encodedTerm string, uid uint64, value uint32) { + dirKey := x.BM25TermDirKey(attr, encodedTerm) + dirBlob := txn.cache.ReadBM25Blob(dirKey) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + // First entry for this term: create a single block. + blockID := dir.AllocBlockID() + entries := []bm25enc.Entry{{UID: uid, Value: value}} + blockKey := x.BM25TermBlockKey(attr, encodedTerm, blockID) + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.Blocks = append(dir.Blocks, bm25block.BlockMetaFromEntries(blockID, entries)) + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) + return + } + + // Find the target block, read it, upsert, and handle splits. + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25TermBlockKey(attr, encodedTerm, bm.BlockID) + blob := txn.cache.ReadBM25Blob(blockKey) + entries := bm25enc.Decode(blob) + entries = bm25enc.Upsert(entries, uid, value) + + if len(entries) > bm25block.MaxBlockSize { + // Split the block. + mid := len(entries) / 2 + left := entries[:mid] + right := entries[mid:] + + // Write left block (reuse existing blockID). + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(left)) + dir.UpdateBlockMeta(blockIdx, left) + + // Write right block (new blockID). + newBlockID := dir.AllocBlockID() + newBlockKey := x.BM25TermBlockKey(attr, encodedTerm, newBlockID) + txn.cache.WriteBM25Blob(newBlockKey, bm25enc.Encode(right)) + dir.InsertBlockMeta(blockIdx+1, bm25block.BlockMetaFromEntries(newBlockID, right)) + } else { + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.UpdateBlockMeta(blockIdx, entries) + } + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) +} + +// bm25BlockRemove removes a uid from the block-based posting list for the given term. +func (txn *Txn) bm25BlockRemove(attr, encodedTerm string, uid uint64) { + dirKey := x.BM25TermDirKey(attr, encodedTerm) + dirBlob := txn.cache.ReadBM25Blob(dirKey) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + return } - // Store document length. - dlKey := x.BM25DocLenKey(attr) - blob := txn.cache.ReadBM25Blob(dlKey) + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25TermBlockKey(attr, encodedTerm, bm.BlockID) + blob := txn.cache.ReadBM25Blob(blockKey) + entries := bm25enc.Decode(blob) + entries = bm25enc.Remove(entries, uid) + + if len(entries) == 0 { + // Block is empty; remove it from the directory. + txn.cache.WriteBM25Blob(blockKey, nil) + dir.RemoveBlockMeta(blockIdx) + } else { + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.UpdateBlockMeta(blockIdx, entries) + } + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) +} + +// bm25DocLenBlockUpsert inserts or updates a doc-length entry in the block-based +// document-length list. +func (txn *Txn) bm25DocLenBlockUpsert(attr string, uid uint64, docLen uint32) { + dirKey := x.BM25DocLenDirKey(attr) + dirBlob := txn.cache.ReadBM25Blob(dirKey) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + blockID := dir.AllocBlockID() + entries := []bm25enc.Entry{{UID: uid, Value: docLen}} + blockKey := x.BM25DocLenBlockKey(attr, blockID) + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.Blocks = append(dir.Blocks, bm25block.BlockMetaFromEntries(blockID, entries)) + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) + return + } + + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) + blob := txn.cache.ReadBM25Blob(blockKey) entries := bm25enc.Decode(blob) entries = bm25enc.Upsert(entries, uid, docLen) - txn.cache.WriteBM25Blob(dlKey, bm25enc.Encode(entries)) - // Update corpus stats: increment doc count by 1 and total terms by docLen. - return txn.updateBM25Stats(attr, 1, int64(docLen)) + if len(entries) > bm25block.MaxBlockSize { + mid := len(entries) / 2 + left := entries[:mid] + right := entries[mid:] + + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(left)) + dir.UpdateBlockMeta(blockIdx, left) + + newBlockID := dir.AllocBlockID() + newBlockKey := x.BM25DocLenBlockKey(attr, newBlockID) + txn.cache.WriteBM25Blob(newBlockKey, bm25enc.Encode(right)) + dir.InsertBlockMeta(blockIdx+1, bm25block.BlockMetaFromEntries(newBlockID, right)) + } else { + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.UpdateBlockMeta(blockIdx, entries) + } + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) +} + +// bm25DocLenBlockRemove removes a uid from the block-based document-length list. +func (txn *Txn) bm25DocLenBlockRemove(attr string, uid uint64) { + dirKey := x.BM25DocLenDirKey(attr) + dirBlob := txn.cache.ReadBM25Blob(dirKey) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + return + } + + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) + blob := txn.cache.ReadBM25Blob(blockKey) + entries := bm25enc.Decode(blob) + entries = bm25enc.Remove(entries, uid) + + if len(entries) == 0 { + txn.cache.WriteBM25Blob(blockKey, nil) + dir.RemoveBlockMeta(blockIdx) + } else { + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.UpdateBlockMeta(blockIdx, entries) + } + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) } // updateBM25Stats reads the current corpus statistics for a BM25-indexed attribute, diff --git a/worker/bm25wand.go b/worker/bm25wand.go new file mode 100644 index 00000000000..7fc9dc74ec1 --- /dev/null +++ b/worker/bm25wand.go @@ -0,0 +1,501 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package worker + +import ( + "container/heap" + "math" + "sort" + + "github.com/dgraph-io/dgraph/v25/posting" + "github.com/dgraph-io/dgraph/v25/posting/bm25block" + "github.com/dgraph-io/dgraph/v25/posting/bm25enc" + "github.com/dgraph-io/dgraph/v25/x" +) + +// listIter iterates over a term's block-based posting list for WAND scoring. +type listIter struct { + attr string + encodedTerm string + readTs uint64 + idf float64 + k, b float64 + + dir *bm25block.Dir + ubPreSuf []float64 // suffix max of UBPre values + blockIdx int // current block index in dir.Blocks + block []bm25enc.Entry // decoded current block + inBlockPos int // position within current block + + exhausted bool +} + +// newListIter creates a new iterator for a term's block-based posting list. +func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64) *listIter { + dirKey := x.BM25TermDirKey(attr, encodedTerm) + dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + return &listIter{exhausted: true} + } + + it := &listIter{ + attr: attr, + encodedTerm: encodedTerm, + readTs: readTs, + idf: idf, + k: k, + b: b, + dir: dir, + ubPreSuf: bm25block.SuffixMaxUBPre(dir, k, b), + blockIdx: -1, // will be advanced on first Next() + } + return it +} + +// currentDoc returns the UID at the current position. +func (it *listIter) currentDoc() uint64 { + if it.exhausted || it.block == nil || it.inBlockPos >= len(it.block) { + return math.MaxUint64 + } + return it.block[it.inBlockPos].UID +} + +// currentTF returns the term frequency at the current position. +func (it *listIter) currentTF() uint32 { + if it.exhausted || it.block == nil || it.inBlockPos >= len(it.block) { + return 0 + } + return it.block[it.inBlockPos].Value +} + +// remainingUB returns the IDF-weighted upper-bound score for the remaining postings. +func (it *listIter) remainingUB() float64 { + if it.exhausted || it.blockIdx >= len(it.ubPreSuf) { + return 0 + } + return it.idf * it.ubPreSuf[it.blockIdx] +} + +// blockUB returns the IDF-weighted upper-bound for the current block only. +func (it *listIter) blockUB() float64 { + if it.exhausted || it.blockIdx < 0 || it.blockIdx >= len(it.dir.Blocks) { + return 0 + } + return it.idf * bm25block.ComputeUBPre(it.dir.Blocks[it.blockIdx].MaxTF, it.k, it.b) +} + +// next advances to the next posting. Returns false if exhausted. +func (it *listIter) next() bool { + if it.exhausted { + return false + } + + // Try advancing within the current block. + if it.block != nil { + it.inBlockPos++ + if it.inBlockPos < len(it.block) { + return true + } + } + + // Move to the next block. + it.blockIdx++ + if it.blockIdx >= len(it.dir.Blocks) { + it.exhausted = true + return false + } + it.loadBlock(it.blockIdx) + return it.inBlockPos < len(it.block) +} + +// skipTo advances to the first posting with UID >= target. +// Returns false if exhausted. +func (it *listIter) skipTo(target uint64) bool { + if it.exhausted { + return false + } + + // If current doc is already >= target, no-op. + if it.block != nil && it.inBlockPos < len(it.block) && it.block[it.inBlockPos].UID >= target { + return true + } + + // Check if target might be in the current block. + if it.block != nil && it.blockIdx < len(it.dir.Blocks) { + lastInBlock := it.block[len(it.block)-1].UID + if target <= lastInBlock { + // Binary search within current block. + pos := sort.Search(len(it.block)-it.inBlockPos, func(i int) bool { + return it.block[it.inBlockPos+i].UID >= target + }) + it.inBlockPos += pos + if it.inBlockPos < len(it.block) { + return true + } + } + } + + // Find the right block using the directory. + blockIdx := it.findBlockForTarget(target) + if blockIdx >= len(it.dir.Blocks) { + it.exhausted = true + return false + } + + it.blockIdx = blockIdx + it.loadBlock(blockIdx) + + // Binary search within the block. + pos := sort.Search(len(it.block), func(i int) bool { + return it.block[i].UID >= target + }) + it.inBlockPos = pos + if pos >= len(it.block) { + // Target is beyond this block; try the next. + return it.next() + } + return true +} + +// skipToWithBMW is like skipTo but uses Block-Max WAND to skip entire blocks +// whose upper bounds can't beat the given threshold. +func (it *listIter) skipToWithBMW(target uint64, theta float64, otherUB float64) bool { + if it.exhausted { + return false + } + + // If current doc is already >= target, no-op. + if it.block != nil && it.inBlockPos < len(it.block) && it.block[it.inBlockPos].UID >= target { + return true + } + + blockIdx := it.findBlockForTarget(target) + for blockIdx < len(it.dir.Blocks) { + // Check if this block's UB combined with other terms can beat theta. + blockUB := it.idf * bm25block.ComputeUBPre(it.dir.Blocks[blockIdx].MaxTF, it.k, it.b) + if blockUB+otherUB > theta { + // This block might have a winner; load and search it. + it.blockIdx = blockIdx + it.loadBlock(blockIdx) + pos := sort.Search(len(it.block), func(i int) bool { + return it.block[i].UID >= target + }) + it.inBlockPos = pos + if pos < len(it.block) { + return true + } + // Fall through to next block. + } + blockIdx++ + // Update target to the next block's firstUID (we've already skipped past target). + if blockIdx < len(it.dir.Blocks) { + target = it.dir.Blocks[blockIdx].FirstUID + } + } + it.exhausted = true + return false +} + +// findBlockForTarget returns the block index that should contain target. +func (it *listIter) findBlockForTarget(target uint64) int { + blocks := it.dir.Blocks + idx := sort.Search(len(blocks), func(i int) bool { + return blocks[i].FirstUID > target + }) + if idx > 0 { + return idx - 1 + } + return 0 +} + +// loadBlock decodes the block at the given directory index. +func (it *listIter) loadBlock(idx int) { + bm := it.dir.Blocks[idx] + blockKey := x.BM25TermBlockKey(it.attr, it.encodedTerm, bm.BlockID) + blob := posting.ReadBM25BlobAt(blockKey, it.readTs) + it.block = bm25enc.Decode(blob) + it.inBlockPos = 0 +} + +// scoredDoc holds a UID and its BM25 score for the min-heap. +type scoredDoc struct { + uid uint64 + score float64 +} + +// topKHeap is a min-heap of scored documents for top-k tracking. +type topKHeap struct { + docs []scoredDoc + k int +} + +func (h *topKHeap) Len() int { return len(h.docs) } +func (h *topKHeap) Less(i, j int) bool { return h.docs[i].score < h.docs[j].score } +func (h *topKHeap) Swap(i, j int) { h.docs[i], h.docs[j] = h.docs[j], h.docs[i] } +func (h *topKHeap) Push(x interface{}) { h.docs = append(h.docs, x.(scoredDoc)) } +func (h *topKHeap) Pop() interface{} { + old := h.docs + n := len(old) + item := old[n-1] + h.docs = old[:n-1] + return item +} + +// threshold returns the minimum score in the heap (the score to beat). +func (h *topKHeap) threshold() float64 { + if len(h.docs) < h.k { + return 0 + } + return h.docs[0].score +} + +// tryPush adds a doc if it beats the current threshold. Returns true if the +// threshold changed. +func (h *topKHeap) tryPush(uid uint64, score float64) bool { + if len(h.docs) < h.k { + heap.Push(h, scoredDoc{uid: uid, score: score}) + return len(h.docs) == h.k // threshold only meaningful once heap is full + } + if score > h.docs[0].score { + h.docs[0] = scoredDoc{uid: uid, score: score} + heap.Fix(h, 0) + return true + } + return false +} + +// sorted returns all docs sorted by score descending, then UID ascending. +func (h *topKHeap) sorted() []scoredDoc { + result := make([]scoredDoc, len(h.docs)) + copy(result, h.docs) + sort.Slice(result, func(i, j int) bool { + if result[i].score != result[j].score { + return result[i].score > result[j].score + } + return result[i].uid < result[j].uid + }) + return result +} + +// bm25Score computes the BM25 score for a single term occurrence. +func bm25Score(idf, tf, dl, avgDL, k, b float64) float64 { + return idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) +} + +// lookupDocLen looks up a single UID's document length from the block-based doclen store. +func lookupDocLen(attr string, uid, readTs uint64) float64 { + dirKey := x.BM25DocLenDirKey(attr) + dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + return 1.0 // fallback + } + + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) + blob := posting.ReadBM25BlobAt(blockKey, readTs) + entries := bm25enc.Decode(blob) + if v, ok := bm25enc.Search(entries, uid); ok { + return float64(v) + } + return 1.0 +} + +// wandSearch performs a WAND top-k search over block-based posting lists. +// If topK <= 0, it scores all matching documents (no early termination). +func wandSearch(attr string, readTs uint64, queryTokens []string, + k, b, avgDL, N float64, topK int, filterSet map[uint64]struct{}, + useBMW bool) []scoredDoc { + + // Build iterators for each query term. + var iters []*listIter + for _, token := range queryTokens { + dirKey := x.BM25TermDirKey(attr, token) + dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) + dir := bm25block.DecodeDir(dirBlob) + if len(dir.Blocks) == 0 { + continue + } + + // Compute df from directory. + var df uint64 + for _, bm := range dir.Blocks { + df += uint64(bm.Count) + } + idf := math.Log1p((N - float64(df) + 0.5) / (float64(df) + 0.5)) + + it := newListIter(attr, token, readTs, idf, k, b) + if !it.exhausted { + it.next() // prime the iterator + if !it.exhausted { + iters = append(iters, it) + } + } + } + + if len(iters) == 0 { + return nil + } + + // If no top-k limit, score all matching documents. + if topK <= 0 { + return scoreAllDocs(iters, attr, readTs, k, b, avgDL, filterSet) + } + + // WAND algorithm with top-k heap. + h := &topKHeap{k: topK} + heap.Init(h) + + for { + // Remove exhausted iterators. + active := iters[:0] + for _, it := range iters { + if !it.exhausted { + active = append(active, it) + } + } + iters = active + if len(iters) == 0 { + break + } + + // Sort iterators by currentDoc ascending. + sort.Slice(iters, func(i, j int) bool { + return iters[i].currentDoc() < iters[j].currentDoc() + }) + + theta := h.threshold() + + // Find pivot: accumulate UBs until they exceed theta. + var sumUB float64 + pivot := -1 + var pivotDoc uint64 + for i, it := range iters { + sumUB += it.remainingUB() + if sumUB > theta { + pivot = i + pivotDoc = it.currentDoc() + break + } + } + if pivot == -1 { + break // sum of all UBs can't beat theta + } + + // Advance all iterators before pivot to pivotDoc. + allAtPivot := true + for i := 0; i < pivot; i++ { + if iters[i].currentDoc() < pivotDoc { + var ok bool + if useBMW { + // Compute other UBs for BMW skipping. + var otherUB float64 + for j, jt := range iters { + if j != i { + otherUB += jt.remainingUB() + } + } + ok = iters[i].skipToWithBMW(pivotDoc, theta, otherUB) + } else { + ok = iters[i].skipTo(pivotDoc) + } + if !ok { + allAtPivot = false + break + } + if iters[i].currentDoc() != pivotDoc { + allAtPivot = false + } + } + } + + if !allAtPivot { + continue // re-evaluate after advances + } + + // All iterators up to pivot are at pivotDoc. Score the candidate. + if filterSet != nil { + if _, ok := filterSet[pivotDoc]; !ok { + // Skip this doc (filtered out). Advance all iters at pivotDoc. + for _, it := range iters { + if it.currentDoc() == pivotDoc { + it.next() + } + } + continue + } + } + + dl := lookupDocLen(attr, pivotDoc, readTs) + var score float64 + for _, it := range iters { + if it.currentDoc() == pivotDoc { + tf := float64(it.currentTF()) + score += bm25Score(it.idf, tf, dl, avgDL, k, b) + } + } + h.tryPush(pivotDoc, score) + + // Advance all iterators at pivotDoc. + for _, it := range iters { + if it.currentDoc() == pivotDoc { + it.next() + } + } + } + + return h.sorted() +} + +// scoreAllDocs scores every matching document without early termination. +// Used when no top-k limit is specified (the original behavior). +func scoreAllDocs(iters []*listIter, attr string, readTs uint64, + k, b, avgDL float64, filterSet map[uint64]struct{}) []scoredDoc { + + // Collect all (uid, term) matches. + type termMatch struct { + idf float64 + tf uint32 + } + matches := make(map[uint64][]termMatch) + + for _, it := range iters { + for !it.exhausted { + uid := it.currentDoc() + tf := it.currentTF() + if filterSet == nil { + matches[uid] = append(matches[uid], termMatch{idf: it.idf, tf: tf}) + } else if _, ok := filterSet[uid]; ok { + matches[uid] = append(matches[uid], termMatch{idf: it.idf, tf: tf}) + } + it.next() + } + } + + // Score all matching documents. + results := make([]scoredDoc, 0, len(matches)) + for uid, terms := range matches { + dl := lookupDocLen(attr, uid, readTs) + var score float64 + for _, tm := range terms { + score += bm25Score(tm.idf, float64(tm.tf), dl, avgDL, k, b) + } + results = append(results, scoredDoc{uid: uid, score: score}) + } + + // Sort by score descending, then UID ascending. + sort.Slice(results, func(i, j int) bool { + if results[i].score != results[j].score { + return results[i].score > results[j].score + } + return results[i].uid < results[j].uid + }) + return results +} diff --git a/worker/task.go b/worker/task.go index fbc3189a42b..55697973275 100644 --- a/worker/task.go +++ b/worker/task.go @@ -31,6 +31,7 @@ import ( "github.com/dgraph-io/dgraph/v25/conn" "github.com/dgraph-io/dgraph/v25/posting" "github.com/dgraph-io/dgraph/v25/posting/bm25enc" + // bm25block and bm25wand are used via bm25wand.go in this package. "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" ctask "github.com/dgraph-io/dgraph/v25/task" @@ -1273,7 +1274,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error return nil } - // 3. Read corpus stats from direct Badger KV. + // 3. Read corpus stats. statsKey := x.BM25StatsKey(attr) statsBlob := posting.ReadBM25BlobAt(statsKey, q.ReadTs) docCount, totalTerms := bm25enc.DecodeStats(statsBlob) @@ -1284,7 +1285,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error avgDL := float64(totalTerms) / float64(docCount) N := float64(docCount) - // Build filter set early if used as a filter, for efficient intersection during iteration. + // Build filter set if used as a filter. var filterSet map[uint64]struct{} if q.UidList != nil && len(q.UidList.Uids) > 0 { filterSet = make(map[uint64]struct{}, len(q.UidList.Uids)) @@ -1293,86 +1294,18 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - // 4. For each query token, read the BM25 term blob and collect term info. - type termInfo struct { - idf float64 - uidTFs map[uint64]uint32 + // 4. Determine top-k: use WAND when first is set and no offset. + // When offset is set or first is unset, score all documents. + topK := 0 + if q.First > 0 && q.Offset == 0 { + topK = int(q.First) } - termInfos := make(map[string]*termInfo) - for _, token := range queryTokens { - key := x.BM25IndexKey(attr, token) - blob := posting.ReadBM25BlobAt(key, q.ReadTs) - entries := bm25enc.Decode(blob) - if len(entries) == 0 { - continue - } - - ti := &termInfo{uidTFs: make(map[uint64]uint32)} - df := float64(len(entries)) - for _, e := range entries { - if filterSet != nil { - if _, ok := filterSet[e.UID]; !ok { - continue - } - } - ti.uidTFs[e.UID] = e.Value - } - ti.idf = math.Log1p((N - df + 0.5) / (df + 0.5)) - termInfos[token] = ti - } - - // 5. Read doc lengths for all UIDs seen using binary search on the doclen blob. - allUids := make(map[uint64]struct{}) - for _, ti := range termInfos { - for uid := range ti.uidTFs { - allUids[uid] = struct{}{} - } - } - - dlKey := x.BM25DocLenKey(attr) - dlBlob := posting.ReadBM25BlobAt(dlKey, q.ReadTs) - dlEntries := bm25enc.Decode(dlBlob) - - docLens := make(map[uint64]uint32, len(allUids)) - for uid := range allUids { - if v, ok := bm25enc.Search(dlEntries, uid); ok { - docLens[uid] = v - } - } - - // 6. Compute final BM25 scores. - scores := make(map[uint64]float64) - for _, ti := range termInfos { - for uid, tf := range ti.uidTFs { - dl := float64(1) - if v, ok := docLens[uid]; ok { - dl = float64(v) - } - tfFloat := float64(tf) - score := ti.idf * (k + 1) * tfFloat / (k*(1-b+b*dl/avgDL) + tfFloat) - scores[uid] += score - } - } - - // 7. Sort by score descending. - type uidScore struct { - uid uint64 - score float64 - } - results := make([]uidScore, 0, len(scores)) - for uid, score := range scores { - results = append(results, uidScore{uid: uid, score: score}) - } - sort.Slice(results, func(i, j int) bool { - if results[i].score != results[j].score { - return results[i].score > results[j].score - } - return results[i].uid < results[j].uid - }) + // 5. Run WAND search over block-based posting lists (with Block-Max skipping). + results := wandSearch(attr, q.ReadTs, queryTokens, k, b, avgDL, N, topK, filterSet, true) - // Apply first/offset pagination on score-sorted results before returning UIDs. - if q.First > 0 || q.Offset > 0 { + // 6. Apply first/offset pagination on score-sorted results. + if topK <= 0 && (q.First > 0 || q.Offset > 0) { offset := int(q.Offset) if offset > len(results) { offset = len(results) @@ -1383,7 +1316,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - // Build output: UIDs sorted ascending (required by query pipeline) + // 7. Build output: UIDs sorted ascending (required by query pipeline) // and ValueMatrix with aligned scores (for bm25_score pseudo-predicate). sort.Slice(results, func(i, j int) bool { return results[i].uid < results[j].uid }) uids := make([]uint64, len(results)) @@ -1392,8 +1325,6 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) - // Populate ValueMatrix with BM25 scores aligned to UIDs. - // Each entry is a ValueList with a single float64 value. scoreValues := make([]*pb.ValueList, len(results)) for i, r := range results { buf := make([]byte, 8) From 6373cfef24001a2a7ca225d8b73bb34d5eb5cc6d Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 08:12:59 -0500 Subject: [PATCH 08/12] feat(bm25): add legacy format fallback for migration and WAND unit tests Phase 5 - Migration support: - newListIter falls back to legacy monolithic blob when no block directory exists - lookupDocLen falls back to legacy BM25DocLenKey blob - wandSearch falls back to legacy BM25IndexKey for df computation - Legacy data transparently served through synthetic single-block directory - New writes always use block format; old data works until overwritten Unit tests for WAND components: - TestTopKHeapBasic: heap operations, threshold, eviction - TestTopKHeapTieBreaking: deterministic ordering on score ties - TestBm25ScoreFunction: formula verification, tf/dl/b edge cases - TestBm25ScoreNaN: no NaN/Inf for edge-case inputs Co-Authored-By: Claude Opus 4.6 --- worker/bm25wand.go | 77 +++++++++++++++++++++++++++++---- worker/bm25wand_test.go | 96 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 9 deletions(-) create mode 100644 worker/bm25wand_test.go diff --git a/worker/bm25wand.go b/worker/bm25wand.go index 7fc9dc74ec1..c946ecd9202 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -31,16 +31,55 @@ type listIter struct { inBlockPos int // position within current block exhausted bool + legacy bool // true if using legacy monolithic blob (migration fallback) } // newListIter creates a new iterator for a term's block-based posting list. +// Falls back to the legacy monolithic blob format if no block directory exists. func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64) *listIter { dirKey := x.BM25TermDirKey(attr, encodedTerm) dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) dir := bm25block.DecodeDir(dirBlob) if len(dir.Blocks) == 0 { - return &listIter{exhausted: true} + // Fallback: try reading the legacy monolithic blob and wrap it as a single block. + legacyKey := x.BM25IndexKey(attr, encodedTerm) + legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) + legacyEntries := bm25enc.Decode(legacyBlob) + if len(legacyEntries) == 0 { + return &listIter{exhausted: true} + } + // Build a synthetic single-block directory from the legacy data. + var maxTF uint32 + for _, e := range legacyEntries { + if e.Value > maxTF { + maxTF = e.Value + } + } + dir = &bm25block.Dir{ + NextID: 1, + Blocks: []bm25block.BlockMeta{{ + FirstUID: legacyEntries[0].UID, + BlockID: 0, + Count: uint32(len(legacyEntries)), + MaxTF: maxTF, + }}, + } + it := &listIter{ + attr: attr, + encodedTerm: encodedTerm, + readTs: readTs, + idf: idf, + k: k, + b: b, + dir: dir, + ubPreSuf: bm25block.SuffixMaxUBPre(dir, k, b), + blockIdx: 0, + block: legacyEntries, // pre-loaded + inBlockPos: -1, // will advance on first next() + legacy: true, + } + return it } it := &listIter{ @@ -215,6 +254,11 @@ func (it *listIter) findBlockForTarget(target uint64) int { // loadBlock decodes the block at the given directory index. func (it *listIter) loadBlock(idx int) { + if it.legacy { + // Legacy mode: single block already loaded. + it.inBlockPos = 0 + return + } bm := it.dir.Blocks[idx] blockKey := x.BM25TermBlockKey(it.attr, it.encodedTerm, bm.BlockID) blob := posting.ReadBM25BlobAt(blockKey, it.readTs) @@ -288,13 +332,21 @@ func bm25Score(idf, tf, dl, avgDL, k, b float64) float64 { } // lookupDocLen looks up a single UID's document length from the block-based doclen store. +// Falls back to the legacy monolithic doclen blob if no block directory exists. func lookupDocLen(attr string, uid, readTs uint64) float64 { dirKey := x.BM25DocLenDirKey(attr) dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) dir := bm25block.DecodeDir(dirBlob) if len(dir.Blocks) == 0 { - return 1.0 // fallback + // Fallback: try the legacy monolithic doclen blob. + legacyKey := x.BM25DocLenKey(attr) + legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) + legacyEntries := bm25enc.Decode(legacyBlob) + if v, ok := bm25enc.Search(legacyEntries, uid); ok { + return float64(v) + } + return 1.0 } blockIdx := dir.FindBlock(uid) @@ -317,17 +369,24 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, // Build iterators for each query term. var iters []*listIter for _, token := range queryTokens { + // Compute df: try block directory first, then fall back to legacy blob. + var df uint64 dirKey := x.BM25TermDirKey(attr, token) dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) dir := bm25block.DecodeDir(dirBlob) - if len(dir.Blocks) == 0 { - continue + if len(dir.Blocks) > 0 { + for _, bm := range dir.Blocks { + df += uint64(bm.Count) + } + } else { + // Legacy fallback: read the monolithic blob to get df. + legacyKey := x.BM25IndexKey(attr, token) + legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) + legacyEntries := bm25enc.Decode(legacyBlob) + df = uint64(len(legacyEntries)) } - - // Compute df from directory. - var df uint64 - for _, bm := range dir.Blocks { - df += uint64(bm.Count) + if df == 0 { + continue } idf := math.Log1p((N - float64(df) + 0.5) / (float64(df) + 0.5)) diff --git a/worker/bm25wand_test.go b/worker/bm25wand_test.go new file mode 100644 index 00000000000..5982f94d0b8 --- /dev/null +++ b/worker/bm25wand_test.go @@ -0,0 +1,96 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package worker + +import ( + "container/heap" + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTopKHeapBasic(t *testing.T) { + h := &topKHeap{k: 3} + heap.Init(h) + + require.Equal(t, 0.0, h.threshold()) + + h.tryPush(1, 5.0) + h.tryPush(2, 3.0) + require.Equal(t, 0.0, h.threshold()) // not full yet + + h.tryPush(3, 7.0) + require.InEpsilon(t, 3.0, h.threshold(), 1e-9) // full, min is 3.0 + + h.tryPush(4, 4.0) + require.InEpsilon(t, 4.0, h.threshold(), 1e-9) // 3.0 evicted, min is now 4.0 + + // 2.0 shouldn't be accepted. + h.tryPush(5, 2.0) + require.InEpsilon(t, 4.0, h.threshold(), 1e-9) + + sorted := h.sorted() + require.Len(t, sorted, 3) + require.Equal(t, uint64(3), sorted[0].uid) // highest score (7.0) + require.Equal(t, uint64(1), sorted[1].uid) // 5.0 + require.Equal(t, uint64(4), sorted[2].uid) // 4.0 +} + +func TestTopKHeapTieBreaking(t *testing.T) { + h := &topKHeap{k: 5} + heap.Init(h) + + // Same score, different UIDs — should sort by UID ascending. + h.tryPush(10, 5.0) + h.tryPush(5, 5.0) + h.tryPush(15, 5.0) + + sorted := h.sorted() + require.Equal(t, uint64(5), sorted[0].uid) + require.Equal(t, uint64(10), sorted[1].uid) + require.Equal(t, uint64(15), sorted[2].uid) +} + +func TestBm25ScoreFunction(t *testing.T) { + k, b := 1.2, 0.75 + avgDL := 10.0 + + // idf * (k+1) * tf / (k*(1-b+b*dl/avgDL) + tf) + idf := 1.5 + tf := 3.0 + dl := 10.0 + + expected := idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) + got := bm25Score(idf, tf, dl, avgDL, k, b) + require.InEpsilon(t, expected, got, 1e-9) + + // With b=0: no length normalization. + expected0 := idf * (k + 1) * tf / (k + tf) + got0 := bm25Score(idf, tf, dl, avgDL, k, 0) + require.InEpsilon(t, expected0, got0, 1e-9) + + // Score should be positive for positive inputs. + require.Greater(t, bm25Score(1.0, 1.0, 5.0, 10.0, k, b), 0.0) + + // Higher tf should produce higher score (same dl). + s1 := bm25Score(idf, 1.0, dl, avgDL, k, b) + s3 := bm25Score(idf, 3.0, dl, avgDL, k, b) + require.Greater(t, s3, s1) + + // Shorter doc should score higher (same tf). + sShort := bm25Score(idf, tf, 5.0, avgDL, k, b) + sLong := bm25Score(idf, tf, 20.0, avgDL, k, b) + require.Greater(t, sShort, sLong) +} + +func TestBm25ScoreNaN(t *testing.T) { + // Ensure no NaN/Inf for edge-case inputs. + score := bm25Score(0.5, 1.0, 0.0, 10.0, 1.2, 0.75) + require.False(t, math.IsNaN(score)) + require.False(t, math.IsInf(score, 0)) + require.Greater(t, score, 0.0) +} From f21c20ab6a7b99e08814e82f16c829091d28461c Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 08:26:11 -0500 Subject: [PATCH 09/12] fix(bm25): address GPT-5 code review findings in WAND implementation Fixes critical bugs and performance issues identified by GPT-5 review: - Fix negative inBlockPos panic: guard currentDoc/currentTF/skipTo against inBlockPos < 0 (possible before first next() call) - Fix empty block pathological behavior: next()/skipTo()/skipToWithBMW() now skip empty blocks instead of leaving iterator in invalid state with MaxUint64 pivotDoc - Fix legacy loadBlock: no longer resets inBlockPos to 0 (was moving pointer backwards, could cause re-scoring or infinite loops) - Fix remainingUB panic: guard against blockIdx < 0 (before first next()) - Add docLenCache: caches doclen directory + block reads within a single query, avoiding repeated Badger reads per scored document - Optimize BMW otherUB: compute as sumUB - thisUB (O(1)) instead of iterating all other terms (O(q^2) -> O(q)) Co-Authored-By: Claude Opus 4.6 --- worker/bm25wand.go | 168 ++++++++++++++++++++++++++++++--------------- 1 file changed, 113 insertions(+), 55 deletions(-) diff --git a/worker/bm25wand.go b/worker/bm25wand.go index c946ecd9202..5aab0cd0e44 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -24,11 +24,11 @@ type listIter struct { idf float64 k, b float64 - dir *bm25block.Dir - ubPreSuf []float64 // suffix max of UBPre values - blockIdx int // current block index in dir.Blocks - block []bm25enc.Entry // decoded current block - inBlockPos int // position within current block + dir *bm25block.Dir + ubPreSuf []float64 // suffix max of UBPre values + blockIdx int // current block index in dir.Blocks + block []bm25enc.Entry // decoded current block + inBlockPos int // position within current block exhausted bool legacy bool // true if using legacy monolithic blob (migration fallback) @@ -98,7 +98,7 @@ func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64) *li // currentDoc returns the UID at the current position. func (it *listIter) currentDoc() uint64 { - if it.exhausted || it.block == nil || it.inBlockPos >= len(it.block) { + if it.exhausted || it.block == nil || it.inBlockPos < 0 || it.inBlockPos >= len(it.block) { return math.MaxUint64 } return it.block[it.inBlockPos].UID @@ -106,7 +106,7 @@ func (it *listIter) currentDoc() uint64 { // currentTF returns the term frequency at the current position. func (it *listIter) currentTF() uint32 { - if it.exhausted || it.block == nil || it.inBlockPos >= len(it.block) { + if it.exhausted || it.block == nil || it.inBlockPos < 0 || it.inBlockPos >= len(it.block) { return 0 } return it.block[it.inBlockPos].Value @@ -114,10 +114,17 @@ func (it *listIter) currentTF() uint32 { // remainingUB returns the IDF-weighted upper-bound score for the remaining postings. func (it *listIter) remainingUB() float64 { - if it.exhausted || it.blockIdx >= len(it.ubPreSuf) { + if it.exhausted || len(it.ubPreSuf) == 0 { return 0 } - return it.idf * it.ubPreSuf[it.blockIdx] + idx := it.blockIdx + if idx < 0 { + idx = 0 + } + if idx >= len(it.ubPreSuf) { + return 0 + } + return it.idf * it.ubPreSuf[idx] } // blockUB returns the IDF-weighted upper-bound for the current block only. @@ -137,19 +144,24 @@ func (it *listIter) next() bool { // Try advancing within the current block. if it.block != nil { it.inBlockPos++ - if it.inBlockPos < len(it.block) { + if it.inBlockPos >= 0 && it.inBlockPos < len(it.block) { return true } } // Move to the next block. - it.blockIdx++ - if it.blockIdx >= len(it.dir.Blocks) { - it.exhausted = true - return false + for { + it.blockIdx++ + if it.blockIdx >= len(it.dir.Blocks) { + it.exhausted = true + return false + } + it.loadBlock(it.blockIdx) + if len(it.block) > 0 { + return true + } + // Empty block (corruption/race): skip it. } - it.loadBlock(it.blockIdx) - return it.inBlockPos < len(it.block) } // skipTo advances to the first posting with UID >= target. @@ -160,19 +172,25 @@ func (it *listIter) skipTo(target uint64) bool { } // If current doc is already >= target, no-op. - if it.block != nil && it.inBlockPos < len(it.block) && it.block[it.inBlockPos].UID >= target { + if it.block != nil && it.inBlockPos >= 0 && it.inBlockPos < len(it.block) && + it.block[it.inBlockPos].UID >= target { return true } // Check if target might be in the current block. - if it.block != nil && it.blockIdx < len(it.dir.Blocks) { + if it.block != nil && len(it.block) > 0 && it.blockIdx >= 0 && + it.blockIdx < len(it.dir.Blocks) { lastInBlock := it.block[len(it.block)-1].UID if target <= lastInBlock { - // Binary search within current block. - pos := sort.Search(len(it.block)-it.inBlockPos, func(i int) bool { - return it.block[it.inBlockPos+i].UID >= target + startPos := it.inBlockPos + if startPos < 0 { + startPos = 0 + } + // Binary search within current block from startPos. + pos := sort.Search(len(it.block)-startPos, func(i int) bool { + return it.block[startPos+i].UID >= target }) - it.inBlockPos += pos + it.inBlockPos = startPos + pos if it.inBlockPos < len(it.block) { return true } @@ -188,6 +206,9 @@ func (it *listIter) skipTo(target uint64) bool { it.blockIdx = blockIdx it.loadBlock(blockIdx) + if len(it.block) == 0 { + return it.next() // skip empty block + } // Binary search within the block. pos := sort.Search(len(it.block), func(i int) bool { @@ -209,7 +230,8 @@ func (it *listIter) skipToWithBMW(target uint64, theta float64, otherUB float64) } // If current doc is already >= target, no-op. - if it.block != nil && it.inBlockPos < len(it.block) && it.block[it.inBlockPos].UID >= target { + if it.block != nil && it.inBlockPos >= 0 && it.inBlockPos < len(it.block) && + it.block[it.inBlockPos].UID >= target { return true } @@ -221,6 +243,10 @@ func (it *listIter) skipToWithBMW(target uint64, theta float64, otherUB float64) // This block might have a winner; load and search it. it.blockIdx = blockIdx it.loadBlock(blockIdx) + if len(it.block) == 0 { + blockIdx++ + continue // skip empty block + } pos := sort.Search(len(it.block), func(i int) bool { return it.block[i].UID >= target }) @@ -231,7 +257,7 @@ func (it *listIter) skipToWithBMW(target uint64, theta float64, otherUB float64) // Fall through to next block. } blockIdx++ - // Update target to the next block's firstUID (we've already skipped past target). + // Update target to the next block's firstUID. if blockIdx < len(it.dir.Blocks) { target = it.dir.Blocks[blockIdx].FirstUID } @@ -255,8 +281,7 @@ func (it *listIter) findBlockForTarget(target uint64) int { // loadBlock decodes the block at the given directory index. func (it *listIter) loadBlock(idx int) { if it.legacy { - // Legacy mode: single block already loaded. - it.inBlockPos = 0 + // Legacy mode: single pre-loaded block; don't reset position. return } bm := it.dir.Blocks[idx] @@ -331,29 +356,65 @@ func bm25Score(idf, tf, dl, avgDL, k, b float64) float64 { return idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) } -// lookupDocLen looks up a single UID's document length from the block-based doclen store. -// Falls back to the legacy monolithic doclen blob if no block directory exists. -func lookupDocLen(attr string, uid, readTs uint64) float64 { - dirKey := x.BM25DocLenDirKey(attr) - dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) - dir := bm25block.DecodeDir(dirBlob) +// docLenCache caches document length lookups within a single query to avoid +// repeated Badger reads for the same doclen block directory and blocks. +type docLenCache struct { + attr string + readTs uint64 + dir *bm25block.Dir + loaded bool + legacy bool + // Per-block cache: blockIdx -> decoded entries. + blocks map[int][]bm25enc.Entry + // Legacy entries (when using monolithic blob). + legacyEntries []bm25enc.Entry +} - if len(dir.Blocks) == 0 { - // Fallback: try the legacy monolithic doclen blob. - legacyKey := x.BM25DocLenKey(attr) - legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) - legacyEntries := bm25enc.Decode(legacyBlob) - if v, ok := bm25enc.Search(legacyEntries, uid); ok { +func newDocLenCache(attr string, readTs uint64) *docLenCache { + return &docLenCache{ + attr: attr, + readTs: readTs, + blocks: make(map[int][]bm25enc.Entry), + } +} + +func (c *docLenCache) ensureLoaded() { + if c.loaded { + return + } + c.loaded = true + dirKey := x.BM25DocLenDirKey(c.attr) + dirBlob := posting.ReadBM25BlobAt(dirKey, c.readTs) + c.dir = bm25block.DecodeDir(dirBlob) + if len(c.dir.Blocks) == 0 { + // Try legacy. + legacyKey := x.BM25DocLenKey(c.attr) + legacyBlob := posting.ReadBM25BlobAt(legacyKey, c.readTs) + c.legacyEntries = bm25enc.Decode(legacyBlob) + c.legacy = true + } +} + +func (c *docLenCache) lookup(uid uint64) float64 { + c.ensureLoaded() + if c.legacy { + if v, ok := bm25enc.Search(c.legacyEntries, uid); ok { return float64(v) } return 1.0 } - - blockIdx := dir.FindBlock(uid) - bm := dir.Blocks[blockIdx] - blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) - blob := posting.ReadBM25BlobAt(blockKey, readTs) - entries := bm25enc.Decode(blob) + if len(c.dir.Blocks) == 0 { + return 1.0 + } + blockIdx := c.dir.FindBlock(uid) + entries, ok := c.blocks[blockIdx] + if !ok { + bm := c.dir.Blocks[blockIdx] + blockKey := x.BM25DocLenBlockKey(c.attr, bm.BlockID) + blob := posting.ReadBM25BlobAt(blockKey, c.readTs) + entries = bm25enc.Decode(blob) + c.blocks[blockIdx] = entries + } if v, ok := bm25enc.Search(entries, uid); ok { return float64(v) } @@ -366,6 +427,8 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, k, b, avgDL, N float64, topK int, filterSet map[uint64]struct{}, useBMW bool) []scoredDoc { + dlCache := newDocLenCache(attr, readTs) + // Build iterators for each query term. var iters []*listIter for _, token := range queryTokens { @@ -405,7 +468,7 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, // If no top-k limit, score all matching documents. if topK <= 0 { - return scoreAllDocs(iters, attr, readTs, k, b, avgDL, filterSet) + return scoreAllDocs(iters, dlCache, k, b, avgDL, filterSet) } // WAND algorithm with top-k heap. @@ -454,13 +517,8 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, if iters[i].currentDoc() < pivotDoc { var ok bool if useBMW { - // Compute other UBs for BMW skipping. - var otherUB float64 - for j, jt := range iters { - if j != i { - otherUB += jt.remainingUB() - } - } + // Compute otherUB = total UB - this iter's UB (O(1) instead of O(q)). + otherUB := sumUB - iters[i].remainingUB() ok = iters[i].skipToWithBMW(pivotDoc, theta, otherUB) } else { ok = iters[i].skipTo(pivotDoc) @@ -492,7 +550,7 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, } } - dl := lookupDocLen(attr, pivotDoc, readTs) + dl := dlCache.lookup(pivotDoc) var score float64 for _, it := range iters { if it.currentDoc() == pivotDoc { @@ -515,7 +573,7 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, // scoreAllDocs scores every matching document without early termination. // Used when no top-k limit is specified (the original behavior). -func scoreAllDocs(iters []*listIter, attr string, readTs uint64, +func scoreAllDocs(iters []*listIter, dlCache *docLenCache, k, b, avgDL float64, filterSet map[uint64]struct{}) []scoredDoc { // Collect all (uid, term) matches. @@ -541,7 +599,7 @@ func scoreAllDocs(iters []*listIter, attr string, readTs uint64, // Score all matching documents. results := make([]scoredDoc, 0, len(matches)) for uid, terms := range matches { - dl := lookupDocLen(attr, uid, readTs) + dl := dlCache.lookup(uid) var score float64 for _, tm := range terms { score += bm25Score(tm.idf, float64(tm.tf), dl, avgDL, k, b) From 3503a9c56162a16eb449490aa2f093978ebbd096 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 08:53:46 -0500 Subject: [PATCH 10/12] fix(bm25): prevent stats double-counting on updates and fix BMW otherUB underestimate Three fixes: 1. CRITICAL: addBM25IndexMutations now checks if a UID already exists in doclen blocks before incrementing stats, preventing double-counting on SET when the document was already indexed (defensive guard for batch mutations). 2. HIGH: WAND sumUB now accumulates across ALL iterators (not just up to pivot), so BMW's otherUB calculation is correct and won't skip valid candidate blocks. 3. PERF: newListIter accepts pre-read Dir to eliminate duplicate Badger reads (directory was read once for df, then again inside newListIter). Co-Authored-By: Claude Opus 4.6 --- posting/index.go | 38 ++++++++++++++++++++++++++++++++++++-- worker/bm25wand.go | 17 ++++++++++------- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/posting/index.go b/posting/index.go index d2eadd904e0..e19deae8d73 100644 --- a/posting/index.go +++ b/posting/index.go @@ -272,13 +272,26 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn return txn.updateBM25Stats(attr, -1, -int64(docLen)) } - // For SET: upsert term frequencies and doc length into blocks. + // For SET: check if this UID already has a doclen entry (i.e., this is an update). + // If so, subtract old stats to avoid double-counting. + oldDocLen, isUpdate := txn.bm25DocLenBlockLookup(attr, uid) + for term, tf := range termFreqs { encodedTerm := string([]byte{tok.IdentBM25}) + term txn.bm25BlockUpsert(attr, encodedTerm, uid, tf) } txn.bm25DocLenBlockUpsert(attr, uid, docLen) - return txn.updateBM25Stats(attr, 1, int64(docLen)) + + var docCountDelta int64 + var totalTermsDelta int64 + if isUpdate { + // Document already existed: don't increment docCount, adjust totalTerms by diff. + totalTermsDelta = int64(docLen) - int64(oldDocLen) + } else { + docCountDelta = 1 + totalTermsDelta = int64(docLen) + } + return txn.updateBM25Stats(attr, docCountDelta, totalTermsDelta) } // bm25BlockUpsert inserts or updates a (uid, value) entry in the block-based @@ -400,6 +413,27 @@ func (txn *Txn) bm25DocLenBlockUpsert(attr string, uid uint64, docLen uint32) { txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) } +// bm25DocLenBlockLookup checks if a uid exists in the doclen blocks and returns its value. +func (txn *Txn) bm25DocLenBlockLookup(attr string, uid uint64) (uint32, bool) { + dirKey := x.BM25DocLenDirKey(attr) + dirBlob := txn.cache.ReadBM25Blob(dirKey) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + return 0, false + } + + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) + blob := txn.cache.ReadBM25Blob(blockKey) + entries := bm25enc.Decode(blob) + if v, ok := bm25enc.Search(entries, uid); ok { + return v, true + } + return 0, false +} + // bm25DocLenBlockRemove removes a uid from the block-based document-length list. func (txn *Txn) bm25DocLenBlockRemove(attr string, uid uint64) { dirKey := x.BM25DocLenDirKey(attr) diff --git a/worker/bm25wand.go b/worker/bm25wand.go index 5aab0cd0e44..de5ecfb2b3c 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -36,10 +36,13 @@ type listIter struct { // newListIter creates a new iterator for a term's block-based posting list. // Falls back to the legacy monolithic blob format if no block directory exists. -func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64) *listIter { - dirKey := x.BM25TermDirKey(attr, encodedTerm) - dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) - dir := bm25block.DecodeDir(dirBlob) +// If dir is non-nil, it is used directly (avoids re-reading from Badger). +func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64, dir *bm25block.Dir) *listIter { + if dir == nil { + dirKey := x.BM25TermDirKey(attr, encodedTerm) + dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) + dir = bm25block.DecodeDir(dirBlob) + } if len(dir.Blocks) == 0 { // Fallback: try reading the legacy monolithic blob and wrap it as a single block. @@ -453,7 +456,7 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, } idf := math.Log1p((N - float64(df) + 0.5) / (float64(df) + 0.5)) - it := newListIter(attr, token, readTs, idf, k, b) + it := newListIter(attr, token, readTs, idf, k, b, dir) if !it.exhausted { it.next() // prime the iterator if !it.exhausted { @@ -501,12 +504,12 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, var pivotDoc uint64 for i, it := range iters { sumUB += it.remainingUB() - if sumUB > theta { + if sumUB > theta && pivot == -1 { pivot = i pivotDoc = it.currentDoc() - break } } + // sumUB now contains the total UB across ALL iterators (needed for BMW). if pivot == -1 { break // sum of all UBs can't beat theta } From 9093cb001375ad01f8844545df2f2ba8c5607a42 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 09:10:29 -0500 Subject: [PATCH 11/12] fix(bm25): clamp startPos in skipTo to prevent negative sort.Search length Defensive hardening from GPT-5 review: if inBlockPos exceeds block length after next() reaches end of block, the sort.Search span could go negative. Co-Authored-By: Claude Opus 4.6 --- worker/bm25wand.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/worker/bm25wand.go b/worker/bm25wand.go index de5ecfb2b3c..4ae2569fa7a 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -188,6 +188,8 @@ func (it *listIter) skipTo(target uint64) bool { startPos := it.inBlockPos if startPos < 0 { startPos = 0 + } else if startPos > len(it.block) { + startPos = len(it.block) } // Binary search within current block from startPos. pos := sort.Search(len(it.block)-startPos, func(i int) bool { From 31e70e62397ac783a488157a2512ba027a16d452 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 18 Mar 2026 21:23:16 -0400 Subject: [PATCH 12/12] fix(bm25): address Gemini/GPT-5 code review findings - Add DecodeCount() to bm25enc for O(1) entry count reads without full decode, preventing OOM on legacy migration with large posting lists (e.g., common terms with millions of entries) - Use DecodeCount in WAND search legacy DF calculation path - Fix integer overflow in DecodeDir bounds check by using uint64 arithmetic (prevents panic on corrupted data with MaxUint32 count) - Pre-allocate shared score buffer in handleBM25Search with three-index slices to prevent accidental append corruption - Document bm25Writes concurrency model and limitations Co-Authored-By: Claude Opus 4.6 (1M context) --- posting/bm25block/bm25block.go | 3 +- posting/bm25block/bm25block_test.go | 13 ++++ posting/bm25enc/bm25enc.go | 11 +++ posting/bm25enc/bm25enc_test.go | 31 ++++++++ posting/lists.go | 12 +++ query/query_bm25_test.go | 115 ++++++++++++++++++---------- worker/bm25wand.go | 6 +- worker/task.go | 14 +++- 8 files changed, 159 insertions(+), 46 deletions(-) diff --git a/posting/bm25block/bm25block.go b/posting/bm25block/bm25block.go index f529ed8fab8..e9c4fa1776e 100644 --- a/posting/bm25block/bm25block.go +++ b/posting/bm25block/bm25block.go @@ -77,7 +77,8 @@ func DecodeDir(data []byte) *Dir { } count := binary.BigEndian.Uint32(data[0:4]) nextID := binary.BigEndian.Uint32(data[4:8]) - if int(count)*dirEntrySize+dirHeaderSize > len(data) { + // Use uint64 arithmetic to prevent integer overflow on corrupted data. + if uint64(count)*dirEntrySize+dirHeaderSize > uint64(len(data)) { return &Dir{NextID: nextID} } blocks := make([]BlockMeta, count) diff --git a/posting/bm25block/bm25block_test.go b/posting/bm25block/bm25block_test.go index a7cc26f493a..f9cccb7554a 100644 --- a/posting/bm25block/bm25block_test.go +++ b/posting/bm25block/bm25block_test.go @@ -6,6 +6,7 @@ package bm25block import ( + "encoding/binary" "math" "testing" @@ -39,6 +40,18 @@ func TestDirRoundtripEmpty(t *testing.T) { require.Empty(t, got.Blocks) } +func TestDecodeDirCorruptedLargeCount(t *testing.T) { + // A corrupted blob with a massive count should not panic due to integer overflow. + // count = MaxUint32, nextID = 0, followed by only 8 bytes of data. + data := make([]byte, 16) + binary.BigEndian.PutUint32(data[0:4], 0xFFFFFFFF) // count = MaxUint32 + binary.BigEndian.PutUint32(data[4:8], 0) // nextID = 0 + got := DecodeDir(data) + // Should return an empty Dir (with nextID preserved) rather than panicking. + require.Empty(t, got.Blocks) + require.Equal(t, uint32(0), got.NextID) +} + func TestDirRoundtripSingle(t *testing.T) { dir := &Dir{ NextID: 1, diff --git a/posting/bm25enc/bm25enc.go b/posting/bm25enc/bm25enc.go index 8da82b299dd..86bfe5f5bb1 100644 --- a/posting/bm25enc/bm25enc.go +++ b/posting/bm25enc/bm25enc.go @@ -130,6 +130,17 @@ func UIDs(entries []Entry) []uint64 { return uids } +// DecodeCount reads just the entry count from the header of an encoded blob +// without decoding any entries. This is O(1) and avoids allocating a full +// []Entry slice, which matters for large posting lists (e.g., common terms +// during legacy format migration). +func DecodeCount(data []byte) uint32 { + if len(data) < 4 { + return 0 + } + return binary.BigEndian.Uint32(data[:4]) +} + // EncodeStats encodes BM25 corpus statistics (docCount, totalTerms) as 16 bytes. func EncodeStats(docCount, totalTerms uint64) []byte { buf := make([]byte, 16) diff --git a/posting/bm25enc/bm25enc_test.go b/posting/bm25enc/bm25enc_test.go index 1969e472ed2..f4cfec6bf62 100644 --- a/posting/bm25enc/bm25enc_test.go +++ b/posting/bm25enc/bm25enc_test.go @@ -92,6 +92,37 @@ func TestUIDs(t *testing.T) { require.Equal(t, []uint64{1, 5, 100}, UIDs(entries)) } +func TestDecodeCount(t *testing.T) { + // Normal case: count matches actual entries. + entries := []Entry{ + {UID: 1, Value: 3}, + {UID: 5, Value: 1}, + {UID: 100, Value: 7}, + } + data := Encode(entries) + require.Equal(t, uint32(3), DecodeCount(data)) + + // Empty/nil input. + require.Equal(t, uint32(0), DecodeCount(nil)) + require.Equal(t, uint32(0), DecodeCount([]byte{})) + require.Equal(t, uint32(0), DecodeCount([]byte{1, 2, 3})) + + // Zero count. + require.Equal(t, uint32(0), DecodeCount([]byte{0, 0, 0, 0})) + + // Single entry. + single := Encode([]Entry{{UID: 42, Value: 10}}) + require.Equal(t, uint32(1), DecodeCount(single)) + + // Large count. + large := make([]Entry, 10000) + for i := range large { + large[i] = Entry{UID: uint64(i*3 + 1), Value: uint32(i % 100)} + } + data = Encode(large) + require.Equal(t, uint32(10000), DecodeCount(data)) +} + func TestStatsRoundtrip(t *testing.T) { data := EncodeStats(12345, 98765) dc, tt := DecodeStats(data) diff --git a/posting/lists.go b/posting/lists.go index 0bd9848de23..22d20a53973 100644 --- a/posting/lists.go +++ b/posting/lists.go @@ -79,6 +79,18 @@ type LocalCache struct { // bm25Writes buffers BM25 direct KV writes (key → encoded blob). // These bypass the posting list infrastructure entirely. + // + // CONCURRENCY NOTE: BM25 blocks use full-value overwrites rather than + // posting list deltas. Within a single Dgraph transaction this is safe + // (each Txn has its own LocalCache). Across concurrent transactions, + // Dgraph's Raft-based mutation serialization prevents lost updates for + // the same predicate+UID pair. However, two transactions updating + // different UIDs that share a common term could theoretically race on + // the same term block. In practice this is mitigated by: + // 1. Dgraph serializes mutations through Raft proposals + // 2. Block splits keep contention surface small + // If higher write concurrency is needed, blocks should be integrated + // into the posting list delta mechanism. bm25Writes map[string][]byte } diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index 457c7b46452..1411ad3916e 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -19,6 +19,23 @@ import ( "github.com/stretchr/testify/require" ) +// uidHex queries Dgraph for the hex UID string of a given decimal UID. +// This avoids hardcoding hex values that depend on UID assignment order. +func uidHex(t *testing.T, decimalUID int) string { + t.Helper() + js := processQueryNoErr(t, fmt.Sprintf(`{ me(func: uid(%d)) { uid } }`, decimalUID)) + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.NotEmpty(t, resp.Data.Me, "UID %d should exist", decimalUID) + return resp.Data.Me[0].UID +} + func TestBM25Basic(t *testing.T) { query := ` { @@ -376,9 +393,9 @@ func TestBM25IncrementalAddBatch(t *testing.T) { js = processQueryNoErr(t, countQuery) require.Contains(t, js, `"count":8`) - // Verify specific new UIDs are searchable. - js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "whiskey")) { uid } }`) - require.Contains(t, js, `"0x25e"`) // 606 + // Verify specific new terms are searchable. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "whiskey")) { uid description_bm25 } }`) + require.Contains(t, js, "whiskey") } func TestBM25CorpusStatsAffectIDF(t *testing.T) { @@ -417,7 +434,7 @@ func TestBM25CorpusStatsAffectIDF(t *testing.T) { scoresAfter := parseScoresFromJSON(t, jsAfter) // Compare score for UID 503 ("fox fox fox") — should increase. - uid503 := "0x1f7" + uid503 := uidHex(t, 503) before, ok1 := scoresBefore[uid503] after, ok2 := scoresAfter[uid503] require.True(t, ok1 && ok2, "UID 503 should appear in both before and after results") @@ -432,6 +449,8 @@ func TestBM25DocumentUpdate(t *testing.T) { deleteTriplesInCluster(`<620> * .`) }) + uid620 := uidHex(t, 620) + // Should rank top for "fox". js := processQueryNoErr(t, ` { @@ -439,7 +458,7 @@ func TestBM25DocumentUpdate(t *testing.T) { uid } }`) - require.Contains(t, js, `"0x26c"`) // 620 + require.Contains(t, js, `"`+uid620+`"`) // Update to remove "fox", add "cat". deleteTriplesInCluster(`<620> "fox fox fox fox" .`) @@ -452,7 +471,7 @@ func TestBM25DocumentUpdate(t *testing.T) { uid } }`) - require.NotContains(t, js, `"0x26c"`) + require.NotContains(t, js, `"`+uid620+`"`) // Should appear in "cat" results. js = processQueryNoErr(t, ` @@ -461,7 +480,7 @@ func TestBM25DocumentUpdate(t *testing.T) { uid } }`) - require.Contains(t, js, `"0x26c"`) + require.Contains(t, js, `"`+uid620+`"`) } func TestBM25DocumentDeletion(t *testing.T) { @@ -471,9 +490,11 @@ func TestBM25DocumentDeletion(t *testing.T) { deleteTriplesInCluster(`<625> * .`) }) + uid625 := uidHex(t, 625) + // Should find the elephant doc. js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) - require.Contains(t, js, `"0x271"`) // 625 + require.Contains(t, js, `"`+uid625+`"`) // Delete it. deleteTriplesInCluster(`<625> "unique elephant term" .`) @@ -483,7 +504,7 @@ func TestBM25DocumentDeletion(t *testing.T) { require.JSONEq(t, `{"data": {"me":[]}}`, js) // Baseline "fox" results should be unaffected. - js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "fox")) { uid } }`) + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "fox")) { uid description_bm25 } }`) require.Contains(t, js, "fox") } @@ -499,7 +520,7 @@ func TestBM25ScoreStabilityAsCorpusGrows(t *testing.T) { } } ` - uid503 := "0x1f7" + uid503 := uidHex(t, 503) // Phase 1: baseline score. js1 := processQueryNoErr(t, scoreQuery) @@ -642,8 +663,8 @@ func TestBM25EdgeCaseLongDocument(t *testing.T) { js := processQueryNoErr(t, scoreQuery) scores := parseScoresFromJSON(t, js) - uid503 := "0x1f7" // "fox fox fox" (doclen=3) - uid645 := "0x285" // long doc (doclen~500) + uid503 := uidHex(t, 503) // "fox fox fox" (doclen=3) + uid645 := uidHex(t, 645) // long doc (doclen~500) s503, ok1 := scores[uid503] s645, ok2 := scores[uid645] require.True(t, ok1, "UID 503 must appear in fox results") @@ -667,17 +688,21 @@ func TestBM25EdgeCaseUnicode(t *testing.T) { `) }) + uid650 := uidHex(t, 650) + uid651 := uidHex(t, 651) + uid652 := uidHex(t, 652) + // Query German term. js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "Fuchs")) { uid } }`) - require.Contains(t, js, `"0x28a"`) // 650 + require.Contains(t, js, `"`+uid650+`"`) // Query French term. js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "renard")) { uid } }`) - require.Contains(t, js, `"0x28b"`) // 651 + require.Contains(t, js, `"`+uid651+`"`) // Query Spanish term. js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "zorro")) { uid } }`) - require.Contains(t, js, `"0x28c"`) // 652 + require.Contains(t, js, `"`+uid652+`"`) } func TestBM25EdgeCaseAllStopwordsDoc(t *testing.T) { @@ -686,9 +711,11 @@ func TestBM25EdgeCaseAllStopwordsDoc(t *testing.T) { deleteTriplesInCluster(`<655> * .`) }) + uid655 := uidHex(t, 655) + // Query "the" — should return empty since "the" is a stopword. js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "the")) { uid } }`) - require.NotContains(t, js, `"0x28f"`) // 655 should not appear + require.NotContains(t, js, `"`+uid655+`"`) // 655 should not appear // But the doc should exist via has(). js = processQueryNoErr(t, ` @@ -697,7 +724,7 @@ func TestBM25EdgeCaseAllStopwordsDoc(t *testing.T) { uid } }`) - require.Contains(t, js, `"0x28f"`) + require.Contains(t, js, `"`+uid655+`"`) } func TestBM25WithUidFilter(t *testing.T) { @@ -711,12 +738,16 @@ func TestBM25WithUidFilter(t *testing.T) { } ` js := processQueryNoErr(t, query) + uid501 := uidHex(t, 501) + uid502 := uidHex(t, 502) + uid503 := uidHex(t, 503) + uid506 := uidHex(t, 506) // Should contain only UIDs 501 and 503. - require.Contains(t, js, `"0x1f5"`) // 501 - require.Contains(t, js, `"0x1f7"`) // 503 - // Should NOT contain other fox docs like 502, 506, 507. - require.NotContains(t, js, `"0x1f6"`) // 502 - require.NotContains(t, js, `"0x1fa"`) // 506 + require.Contains(t, js, `"`+uid501+`"`) + require.Contains(t, js, `"`+uid503+`"`) + // Should NOT contain other fox docs like 502, 506. + require.NotContains(t, js, `"`+uid502+`"`) + require.NotContains(t, js, `"`+uid506+`"`) } func TestBM25ScoreValuesAreValidFloats(t *testing.T) { @@ -770,22 +801,23 @@ func TestBM25IncrementalAddThenDeleteThenReadd(t *testing.T) { // Phase 1: add with "elephant". require.NoError(t, addTriplesToCluster(`<670> "elephant roams the savanna" .`)) + uid670 := uidHex(t, 670) js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) - require.Contains(t, js, `"0x29e"`) // 670 + require.Contains(t, js, `"`+uid670+`"`) // Phase 2: delete. deleteTriplesInCluster(`<670> "elephant roams the savanna" .`) js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) - require.NotContains(t, js, `"0x29e"`) + require.NotContains(t, js, `"`+uid670+`"`) // Phase 3: re-add with different content. require.NoError(t, addTriplesToCluster(`<670> "penguin waddles on the ice" .`)) js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "penguin")) { uid } }`) - require.Contains(t, js, `"0x29e"`) + require.Contains(t, js, `"`+uid670+`"`) // "elephant" should still not match 670. js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) - require.NotContains(t, js, `"0x29e"`) + require.NotContains(t, js, `"`+uid670+`"`) } func TestBM25NonIndexedPredicateError(t *testing.T) { @@ -828,11 +860,11 @@ func TestBM25ConcurrentBatchAdd(t *testing.T) { // Spot-check a doc from each batch. for batch := 0; batch < 5; batch++ { - uid := 680 + batch*4 - hexUID := fmt.Sprintf(`"0x%x"`, uid) + decUID := 680 + batch*4 + hexUID := uidHex(t, decUID) term := fmt.Sprintf("batch%d", batch) js = processQueryNoErr(t, fmt.Sprintf(`{ me(func: bm25(description_bm25, "%s")) { uid } }`, term)) - require.Contains(t, js, hexUID, "doc %d from batch %d should be searchable", uid, batch) + require.Contains(t, js, `"`+hexUID+`"`, "doc %d from batch %d should be searchable", decUID, batch) } } @@ -895,10 +927,12 @@ func TestBM25ExactScoreValues(t *testing.T) { // Doc 851 "quasar nebula pulsar": tf=1, b=0 → score = idf * 2.2 * 1 / 2.2 = idf expected851 := idf * (k + 1) * 1.0 / (k + 1.0) - actual850, ok := scores["0x352"] // 850 - require.True(t, ok, "UID 850 (0x352) must be in results") - actual851, ok := scores["0x353"] // 851 - require.True(t, ok, "UID 851 (0x353) must be in results") + uid850 := uidHex(t, 850) + uid851 := uidHex(t, 851) + actual850, ok := scores[uid850] + require.True(t, ok, "UID 850 (%s) must be in results", uid850) + actual851, ok := scores[uid851] + require.True(t, ok, "UID 851 (%s) must be in results", uid851) require.InEpsilon(t, expected850, actual850, 1e-6, "Doc 850 score mismatch: expected %f, got %f (N=%f, df=%f, idf=%f)", @@ -940,8 +974,10 @@ func TestBM25BM15NoLengthNormalization(t *testing.T) { js := processQueryNoErr(t, scoreQuery) scores := parseScoresFromJSON(t, js) - score860, ok1 := scores["0x35c"] // 860 - score861, ok2 := scores["0x35d"] // 861 + uid860 := uidHex(t, 860) + uid861 := uidHex(t, 861) + score860, ok1 := scores[uid860] + score861, ok2 := scores[uid861] require.True(t, ok1, "UID 860 must be in results") require.True(t, ok2, "UID 861 must be in results") @@ -964,8 +1000,8 @@ func TestBM25BM15NoLengthNormalization(t *testing.T) { js = processQueryNoErr(t, scoreQueryDefault) scoresDefault := parseScoresFromJSON(t, js) - defScore860, ok1 := scoresDefault["0x35c"] - defScore861, ok2 := scoresDefault["0x35d"] + defScore860, ok1 := scoresDefault[uid860] + defScore861, ok2 := scoresDefault[uid861] require.True(t, ok1, "UID 860 must be in default results") require.True(t, ok2, "UID 861 must be in default results") require.Greater(t, defScore860, defScore861, @@ -999,8 +1035,9 @@ func TestBM25SingleMatchingDocument(t *testing.T) { require.Len(t, scores, 1, "exactly one document should match 'aardvark'") - actual, ok := scores["0x361"] // 865 - require.True(t, ok, "UID 865 (0x361) must be in results") + uid865 := uidHex(t, 865) + actual, ok := scores[uid865] + require.True(t, ok, "UID 865 (%s) must be in results", uid865) // With df=1, tf=1, b=0, k=1.2: // idf = log1p((N - 1 + 0.5) / (1 + 0.5)) = log1p((N - 0.5) / 1.5) diff --git a/worker/bm25wand.go b/worker/bm25wand.go index 4ae2569fa7a..07988c845df 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -447,11 +447,11 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, df += uint64(bm.Count) } } else { - // Legacy fallback: read the monolithic blob to get df. + // Legacy fallback: read just the count header to get df. + // Avoids decoding the full posting list (which could be huge for common terms). legacyKey := x.BM25IndexKey(attr, token) legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) - legacyEntries := bm25enc.Decode(legacyBlob) - df = uint64(len(legacyEntries)) + df = uint64(bm25enc.DecodeCount(legacyBlob)) } if df == 0 { continue diff --git a/worker/task.go b/worker/task.go index 55697973275..0345e9e75f8 100644 --- a/worker/task.go +++ b/worker/task.go @@ -1318,6 +1318,8 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error // 7. Build output: UIDs sorted ascending (required by query pipeline) // and ValueMatrix with aligned scores (for bm25_score pseudo-predicate). + // We use a single pre-allocated buffer for all score encodings to reduce + // per-result heap allocations. sort.Slice(results, func(i, j int) bool { return results[i].uid < results[j].uid }) uids := make([]uint64, len(results)) for i, r := range results { @@ -1325,12 +1327,18 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) + // Encode scores into ValueMatrix. Each entry in ValueMatrix corresponds + // positionally to a UID in UidMatrix[0], enabling the bm25_score + // pseudo-predicate in query.go to map UIDs to scores. + scoreBuf := make([]byte, len(results)*8) scoreValues := make([]*pb.ValueList, len(results)) for i, r := range results { - buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, math.Float64bits(r.score)) + off := i * 8 + binary.LittleEndian.PutUint64(scoreBuf[off:off+8], math.Float64bits(r.score)) + // Use three-index slice to cap capacity at 8, preventing any downstream + // append from corrupting adjacent scores in the shared backing array. scoreValues[i] = &pb.ValueList{ - Values: []*pb.TaskValue{{Val: buf, ValType: pb.Posting_ValType(pb.Posting_FLOAT)}}, + Values: []*pb.TaskValue{{Val: scoreBuf[off : off+8 : off+8], ValType: pb.Posting_ValType(pb.Posting_FLOAT)}}, } } args.out.ValueMatrix = append(args.out.ValueMatrix, scoreValues...)