diff --git a/README.md b/README.md index 4f330c5..4844b33 100644 --- a/README.md +++ b/README.md @@ -644,6 +644,8 @@ Triple-stream retrieval combining three signals: Fused with Reciprocal Rank Fusion (RRF, k=60) and session-diversified (max 3 results per session). +BM25 tokenizes Greek, Cyrillic, Hebrew, Arabic, and accented Latin out of the box. For Chinese / Japanese / Korean memories, install the optional segmenters (`npm install @node-rs/jieba tiny-segmenter`) to split CJK runs into word-level tokens; without them, agentmemory soft-falls to whole-run tokenization and prints a one-time hint on stderr. + ### Embedding providers agentmemory auto-detects your provider. For best results, install local embeddings (free): diff --git a/package.json b/package.json index 5526fbe..0342f6d 100644 --- a/package.json +++ b/package.json @@ -62,9 +62,11 @@ "zod": "^4.0.0" }, "optionalDependencies": { + "@node-rs/jieba": "^2.0.1", "@xenova/transformers": "^2.17.2", "onnxruntime-node": "^1.14.0", - "onnxruntime-web": "^1.14.0" + "onnxruntime-web": "^1.14.0", + "tiny-segmenter": "^0.2.0" }, "devDependencies": { "@types/node": "^22.0.0", diff --git a/src/state/cjk-segmenter.ts b/src/state/cjk-segmenter.ts new file mode 100644 index 0000000..d6d4a19 --- /dev/null +++ b/src/state/cjk-segmenter.ts @@ -0,0 +1,169 @@ +import { createRequire } from "node:module"; + +const cjkRequire = createRequire(import.meta.url); + +const CJK_RE = /[\p{Script=Han}\p{Script=Hiragana}\p{Script=Katakana}\p{Script=Hangul}]/u; +const HAN_RE = /\p{Script=Han}/u; +const KANA_RE = /[\p{Script=Hiragana}\p{Script=Katakana}]/u; +const HANGUL_RE = /\p{Script=Hangul}/u; +const CJK_RUN_RE = /[\p{Script=Han}\p{Script=Hiragana}\p{Script=Katakana}\p{Script=Hangul}]+/gu; +const HANGUL_BLOCK_RE = /[가-힯]+/g; + +type Script = "han" | "kana" | "hangul" | "other"; + +const hintShown = new Set(); + +export function hasCjk(text: string): boolean { + return CJK_RE.test(text); +} + +export function detectScript(text: string): Script { + if (HAN_RE.test(text)) return "han"; + if (KANA_RE.test(text)) return "kana"; + if (HANGUL_RE.test(text)) return "hangul"; + return "other"; +} + +function showHintOnce(key: string, message: string): void { + if (hintShown.has(key)) return; + hintShown.add(key); + if (typeof process !== "undefined" && process.stderr?.write) { + process.stderr.write(`agentmemory: ${message}\n`); + } +} + +interface JiebaInstance { + cut(text: string, hmm?: boolean): string[]; +} + +let jiebaInstance: JiebaInstance | null = null; +let jiebaLoaded = false; + +function getJieba(): JiebaInstance | null { + if (jiebaLoaded) return jiebaInstance; + jiebaLoaded = true; + try { + const mod = cjkRequire("@node-rs/jieba") as { + Jieba: { + new (): JiebaInstance; + withDict(dict: Uint8Array): JiebaInstance; + }; + }; + try { + const dictMod = cjkRequire("@node-rs/jieba/dict") as { dict: Uint8Array }; + jiebaInstance = mod.Jieba.withDict(dictMod.dict); + } catch { + jiebaInstance = new mod.Jieba(); + } + return jiebaInstance; + } catch { + showHintOnce( + "jieba", + "install @node-rs/jieba to improve Chinese search; falling back to whole-string tokenization", + ); + return null; + } +} + +interface JaSegmenter { + segment(text: string): string[]; +} + +let jaSegmenterInstance: JaSegmenter | null = null; +let jaSegmenterLoaded = false; + +function getJaSegmenter(): JaSegmenter | null { + if (jaSegmenterLoaded) return jaSegmenterInstance; + jaSegmenterLoaded = true; + try { + const Ctor = cjkRequire("tiny-segmenter") as new () => JaSegmenter; + jaSegmenterInstance = new Ctor(); + return jaSegmenterInstance; + } catch { + showHintOnce( + "tiny-segmenter", + "install tiny-segmenter to improve Japanese search; falling back to whole-string tokenization", + ); + return null; + } +} + +function cleanTokens(tokens: string[]): string[] { + const out: string[] = []; + for (const t of tokens) { + const trimmed = t.trim(); + if (trimmed) out.push(trimmed); + } + return out; +} + +function segmentHan(text: string): string[] { + const j = getJieba(); + if (!j) return [text]; + try { + return cleanTokens(j.cut(text, true)); + } catch { + return [text]; + } +} + +function segmentKana(text: string): string[] { + const s = getJaSegmenter(); + if (!s) return [text]; + try { + return cleanTokens(s.segment(text)); + } catch { + return [text]; + } +} + +function segmentHangul(text: string): string[] { + const out: string[] = []; + for (const m of text.matchAll(HANGUL_BLOCK_RE)) { + if (m[0]) out.push(m[0]); + } + return out; +} + +export function segmentCjk(text: string): string[] { + if (!hasCjk(text)) return [text]; + + const out: string[] = []; + let cursor = 0; + + for (const m of text.matchAll(CJK_RUN_RE)) { + const start = m.index ?? 0; + const run = m[0]; + const end = start + run.length; + + if (start > cursor) { + const piece = text.slice(cursor, start).trim(); + if (piece) out.push(piece); + } + + if (HANGUL_RE.test(run)) { + out.push(...segmentHangul(run)); + } else if (KANA_RE.test(run)) { + out.push(...segmentKana(run)); + } else { + out.push(...segmentHan(run)); + } + + cursor = end; + } + + if (cursor < text.length) { + const trailing = text.slice(cursor).trim(); + if (trailing) out.push(trailing); + } + + return out; +} + +export function __resetCjkSegmenterStateForTests(): void { + hintShown.clear(); + jiebaInstance = null; + jiebaLoaded = false; + jaSegmenterInstance = null; + jaSegmenterLoaded = false; +} diff --git a/src/state/search-index.ts b/src/state/search-index.ts index d989234..f253007 100644 --- a/src/state/search-index.ts +++ b/src/state/search-index.ts @@ -1,6 +1,7 @@ import type { CompressedObservation } from "../types.js"; import { stem } from "./stemmer.js"; import { getSynonyms } from "./synonyms.js"; +import { segmentCjk, hasCjk } from "./cjk-segmenter.js"; interface IndexEntry { obsId: string; @@ -222,11 +223,19 @@ export class SearchIndex { } private tokenize(text: string): string[] { - return text - .replace(/[^\p{L}\p{N}\s/.\\-_]/gu, " ") - .split(/\s+/) - .filter((t) => t.length > 1) - .map((t) => stem(t)); + const cleaned = text.replace(/[^\p{L}\p{N}\s/.\\-_]/gu, " "); + const out: string[] = []; + for (const raw of cleaned.split(/\s+/)) { + if (raw.length < 2) continue; + if (hasCjk(raw)) { + for (const seg of segmentCjk(raw)) { + if (seg.length >= 1) out.push(seg); + } + } else { + out.push(stem(raw)); + } + } + return out; } private getSortedTerms(): string[] { diff --git a/test/search-index.test.ts b/test/search-index.test.ts index a635456..f134987 100644 --- a/test/search-index.test.ts +++ b/test/search-index.test.ts @@ -1,5 +1,6 @@ import { describe, it, expect, beforeEach } from "vitest"; import { SearchIndex } from "../src/state/search-index.js"; +import { segmentCjk } from "../src/state/cjk-segmenter.js"; import type { CompressedObservation } from "../src/types.js"; function makeObs( @@ -153,4 +154,65 @@ describe("SearchIndex", () => { expect(results.length).toBe(1); expect(results[0].obsId).toBe("obs_mixed"); }); + + it("segments Chinese (Han) text into words", () => { + index.add( + makeObs({ + id: "obs_zh", + title: "项目记忆存储", + narrative: "我们正在测试中文分词", + concepts: ["项目", "记忆"], + }), + ); + const results = index.search("项目"); + expect(results.length).toBeGreaterThan(0); + const hit = results.find((r) => r.obsId === "obs_zh"); + expect(hit).toBeDefined(); + expect(hit!.score).toBeGreaterThan(0); + }); + + it("segments Japanese (kana + kanji) text into words", () => { + index.add( + makeObs({ + id: "obs_ja", + title: "プロジェクト記憶", + narrative: "日本語の分かち書きをテストしています", + concepts: ["プロジェクト", "記憶"], + }), + ); + const results = index.search("プロジェクト"); + expect(results.length).toBeGreaterThan(0); + const hit = results.find((r) => r.obsId === "obs_ja"); + expect(hit).toBeDefined(); + expect(hit!.score).toBeGreaterThan(0); + }); + + it("segments Korean (Hangul) syllable blocks into words", () => { + index.add( + makeObs({ + id: "obs_ko", + title: "프로젝트 메모리 저장소", + narrative: "한국어 검색을 테스트합니다", + concepts: ["프로젝트", "메모리"], + }), + ); + const results = index.search("메모리"); + expect(results.length).toBeGreaterThan(0); + const hit = results.find((r) => r.obsId === "obs_ko"); + expect(hit).toBeDefined(); + expect(hit!.score).toBeGreaterThan(0); + }); + + it("preserves source order across mixed CJK and non-CJK runs", () => { + expect(segmentCjk("hello 项目 world")).toEqual(["hello", "项目", "world"]); + expect(segmentCjk("abc 메모리 def 项目 ghi")).toEqual([ + "abc", + "메모리", + "def", + "项目", + "ghi", + ]); + expect(segmentCjk("leading 项目")).toEqual(["leading", "项目"]); + expect(segmentCjk("项目 trailing")).toEqual(["项目", "trailing"]); + }); });