Skip to content

Commit 5548349

Browse files
Merge pull request #32 from EfficientContext/locomo-retest
Locomo retest with 5x10
2 parents ed7e6b1 + d4925a0 commit 5548349

2 files changed

Lines changed: 184 additions & 72 deletions

File tree

docs/guides/mem0.md

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,21 @@ python examples/mem0_locomo_example.py
5454
| `LOCOMO_MAX_QA` | `150` | Max QA pairs to evaluate |
5555
| `LOCOMO_MAX_TOKENS` | `32` | Max generation tokens |
5656
| `LOCOMO_NUM_TURNS` | `150` | Multi-turn conversation length |
57-
| `LOCOMO_TOP_K_LIST` | `20,100` | Comma-separated top-k values to benchmark |
57+
| `LOCOMO_TOP_K_LIST` | `20,50,5x10` | Top-k values to benchmark. Use `N` for standard top-k (e.g. `20`), or `NxM` to retrieve top-N and repeat each M times to simulate long context (e.g. `5x10` retrieves 5 memories, repeats 10x → 50 total context blocks) |
5858

5959
## Results
6060

61-
LoCoMo conv 0, 102 memories, 150 turns:
61+
Aggregate across all 10 LoCoMo conversations, Qwen2.5-7B-Instruct on 2xA6000 (SGLang, tp=2):
62+
63+
| k | mode | ttft | ttft delta | judge |
64+
|---|---|---|---|---|
65+
| 20 | baseline | 0.0566s | - | 0.428 |
66+
| 20 | reorder | 0.0539s | +4.8% | 0.431 |
67+
| 100 | baseline | 0.1012s | - | 0.437 |
68+
| 100 | reorder | 0.0554s | **+45.3%** | 0.420 |
69+
| 5x10 | baseline | 0.1051s | - | 0.418 |
70+
| 5x10 | reorder | 0.0548s | **+47.8%** | 0.414 |
6271

63-
| k | mode | ttft | judge |
64-
|---|---|---|---|
65-
| 20 | baseline | 0.0377s | 0.440 |
66-
| 20 | reorder | 0.0315s | 0.460 |
67-
| 100 | baseline | 0.1012s | 0.437 |
68-
| 100 | reorder | 0.0554s | 0.420 |
6972

7073
## General usage
7174

examples/mem0_locomo_example.py

Lines changed: 173 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
LOCOMO_URL = "https://raw.githubusercontent.com/snap-research/locomo/main/data/locomo10.json"
1414
LOCOMO_CACHE = Path(__file__).resolve().parent.parent / "tests" / ".locomo_cache" / "locomo10.json"
1515

16-
CONV_INDEX = int(os.environ.get("LOCOMO_CONV_INDEX", "0"))
16+
CONV_INDEX = os.environ.get("LOCOMO_CONV_INDEX", "all")
1717
MAX_QA = int(os.environ.get("LOCOMO_MAX_QA", "150"))
1818
MAX_GEN = int(os.environ.get("LOCOMO_MAX_TOKENS", "32"))
1919
NUM_TURNS = int(os.environ.get("LOCOMO_NUM_TURNS", "150"))
20-
TOP_K_LIST = os.environ.get("LOCOMO_TOP_K_LIST", "20,100")
20+
TOP_K_LIST = os.environ.get("LOCOMO_TOP_K_LIST", "20,50,5x10")
2121

2222

2323
async def _stream_ttft(prompt, model, max_tokens=512, request_id=None):
@@ -56,10 +56,38 @@ def run_ttft(prompt, model, max_tokens=512, request_id=None):
5656
return asyncio.run(_stream_ttft(prompt, model, max_tokens, request_id))
5757

5858

59-
def build_prompt(question, context_str):
60-
return (f"Memories:\n{context_str}\n"
61-
f"Based on the memories above, concisely answer the following "
62-
f"question in as few words as possible.\nQuestion: {question}\nAnswer:")
59+
def build_prompt(question, context_str, importance_ranking=None):
60+
prompt = (f"Memories:\n{context_str}\n"
61+
f"Based on the memories above, concisely answer the following "
62+
f"question in as few words as possible.\n")
63+
if importance_ranking:
64+
prompt += (f"Please read the documents in the following importance ranking:\n"
65+
f"{importance_ranking}\n"
66+
f"Prioritize information from higher-ranked documents.\n")
67+
prompt += f"Question: {question}\nAnswer:"
68+
return prompt
69+
70+
71+
def build_importance_ranking(original_ids, reordered_ids):
72+
"""Map original retrieval order to positions in the reordered doc list.
73+
74+
With repeated docs the same doc_id appears multiple times, so we track
75+
the *first* occurrence of each unique doc in the original order and map
76+
it to its first position in the reordered list.
77+
"""
78+
# First occurrence of each doc in reordered list -> its [Doc_N] position
79+
pos = {}
80+
for i, did in enumerate(reordered_ids):
81+
if did not in pos:
82+
pos[did] = i + 1
83+
# Deduplicate original_ids while preserving order
84+
seen = set()
85+
unique_original = []
86+
for did in original_ids:
87+
if did not in seen:
88+
seen.add(did)
89+
unique_original.append(did)
90+
return " > ".join(f"[Doc_{pos[did]}]" for did in unique_original if did in pos)
6391

6492

6593
def llm_judge(question, prediction, ground_truth):
@@ -120,18 +148,20 @@ def strip_thinking(text):
120148

121149
def build_context_str(doc_ids, corpus_map):
122150
parts = []
123-
for did in doc_ids:
151+
for i, did in enumerate(doc_ids):
124152
entry = corpus_map.get(str(did), {})
125153
text = entry.get("text", entry.get("content", f"[doc {did}]"))
126-
parts.append(text)
154+
parts.append(f"[Doc_{i+1}] {text}")
127155
return "\n\n".join(parts)
128156

129157

130158
def run_multi_turn(retriever, user_id, qa_pairs, model, top_k,
131-
use_reorder=False, cp_available=False):
159+
use_reorder=False, cp_available=False, repeat_times=1):
132160
"""Run multi-turn benchmark: baseline vs reorder."""
133161
label = "reorder" if use_reorder else "baseline"
134-
print(f"\n--- {label} ({NUM_TURNS} turns, k={top_k}) ---")
162+
actual_k = top_k * repeat_times if repeat_times > 1 else top_k
163+
suffix = f" (k={top_k}x{repeat_times}={actual_k} docs)" if repeat_times > 1 else f" (k={top_k})"
164+
print(f"\n--- {label} ({NUM_TURNS} turns,{suffix}) ---")
135165

136166
ttfts, prefix_matches, f1s, judges = [], [], [], []
137167

@@ -146,6 +176,11 @@ def run_multi_turn(retriever, user_id, qa_pairs, model, top_k,
146176
cmap = retriever.get_corpus_map()
147177
doc_ids = s[0]["top_k_doc_id"]
148178

179+
# Repeat docs to create long context if requested
180+
if repeat_times > 1:
181+
doc_ids = doc_ids * repeat_times
182+
183+
original_ids = list(doc_ids) # preserve original retrieval order
149184
reordered_ids = doc_ids
150185
req_id = None
151186
server_prefix_len, server_has_prefix, server_node_id = 0, False, -1
@@ -179,10 +214,19 @@ def run_multi_turn(retriever, user_id, qa_pairs, model, top_k,
179214
# Build context string directly from corpus map
180215
context_str = build_context_str(reordered_ids, cmap)
181216

217+
# Build importance ranking — always include so prompt length is equal
218+
# between baseline and reorder (fair TTFT comparison).
219+
# Baseline: natural order [Doc_1] > [Doc_2] > ...
220+
# Reorder: original retrieval order mapped to reordered positions
221+
if use_reorder and reordered_ids != original_ids:
222+
importance_ranking = build_importance_ranking(original_ids, reordered_ids)
223+
else:
224+
importance_ranking = " > ".join(f"[Doc_{i+1}]" for i in range(len(reordered_ids)))
225+
182226
# Build prompt and measure TTFT
183-
prompt = build_prompt(qa["question"], context_str)
227+
prompt = build_prompt(qa["question"], context_str, importance_ranking)
184228
out = run_ttft(prompt, model, MAX_GEN, request_id=req_id)
185-
gt = str(qa["answer"])
229+
gt = str(qa.get("answer", qa.get("answers", qa.get("gold_answer", ""))))
186230

187231
if idx > 0:
188232
ttfts.append(out["ttft"])
@@ -219,6 +263,7 @@ def run_multi_turn(retriever, user_id, qa_pairs, model, top_k,
219263
"prefix": avg(prefix_matches),
220264
"f1": avg(f1s),
221265
"judge": avg(judges),
266+
"repeat": repeat_times,
222267
}
223268
print(f" [{label}] TTFT={stats['ttft']:.4f}s Prefix={stats['prefix']:.1%}"
224269
f" F1={stats['f1']:.3f} Judge={stats['judge']:.3f}")
@@ -283,60 +328,124 @@ def ingest_conversation(conv_data, retriever, user_id):
283328
run_ttft("Hello, world.", model, max_tokens=4)
284329
print("Warmup done.\n")
285330

286-
retriever = Mem0Retriever(config={
287-
"llm": {"provider": "openai", "config": {"model": "gpt-4.1-mini-2025-04-14"}},
288-
"embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}},
289-
})
290-
291-
conv_data = all_convs[CONV_INDEX]
292-
qa_pairs = conv_data["qa"][:MAX_QA]
293-
conv = conv_data["conversation"]
294-
print(f"\n{'='*70}")
295-
print(f"CONV {CONV_INDEX}: {conv['speaker_a']} & {conv['speaker_b']}, {len(qa_pairs)} QA pairs")
296-
print(f"{'='*70}")
297-
298-
user_id = f"locomo_{CONV_INDEX}_{uuid.uuid4().hex[:6]}"
299-
n_memories = ingest_conversation(conv_data, retriever, user_id)
300-
top_k_values = [int(k) for k in TOP_K_LIST.split(",")]
331+
# Parse TOP_K_LIST: supports "20", "50", or "5x10" (k=5, repeat 10 times)
332+
top_k_configs = []
333+
for entry in TOP_K_LIST.split(","):
334+
entry = entry.strip()
335+
if "x" in entry:
336+
k_str, r_str = entry.split("x", 1)
337+
top_k_configs.append((int(k_str), int(r_str)))
338+
else:
339+
top_k_configs.append((int(entry), 1))
340+
341+
# Determine which conversations to run
342+
if CONV_INDEX == "all":
343+
conv_indices = list(range(len(all_convs)))
344+
else:
345+
conv_indices = [int(CONV_INDEX)]
346+
347+
grand_rows = [] # aggregate across all conversations
348+
349+
for ci in conv_indices:
350+
# Flush SGLang's radix cache between conversations to avoid pressure buildup
351+
try:
352+
requests.post(f"{INFERENCE_URL}/flush_cache", timeout=5)
353+
except Exception:
354+
pass
301355

302-
try:
303-
all_rows = []
304-
for top_k in top_k_values:
305-
print(f"\n## top_k={top_k}")
306-
results = {}
307-
for use_reorder in [True, False]:
308-
cp_reset() # fresh tree for each mode
309-
stats = run_multi_turn(
310-
retriever, user_id, qa_pairs, model, top_k,
311-
use_reorder=use_reorder, cp_available=cp_available)
312-
results[stats["label"]] = stats
313-
314-
base_ttft = results["baseline"]["ttft"]
315-
316-
for name in ["baseline", "reorder"]:
317-
s = results[name]
318-
delta = (base_ttft - s["ttft"]) / base_ttft * 100 if base_ttft else 0
319-
all_rows.append({
320-
"k": top_k,
321-
"mode": name,
322-
"ttft": f"{s['ttft']:.4f}s",
323-
"ttft_delta": f"{delta:+.1f}%" if name != "baseline" else "-",
324-
"prefix": f"{s['prefix']:.1%}",
325-
"f1": f"{s['f1']:.3f}",
326-
"judge": f"{s['judge']:.3f}",
327-
})
328-
329-
# Summary table
356+
conv_data = all_convs[ci]
357+
qa_pairs = conv_data["qa"][:MAX_QA]
358+
conv = conv_data["conversation"]
330359
print(f"\n{'='*70}")
331-
print(f"RESULTS (conv={CONV_INDEX}, memories={n_memories}, turns={min(NUM_TURNS, len(qa_pairs))})")
360+
print(f"CONV {ci}: {conv['speaker_a']} & {conv['speaker_b']}, {len(qa_pairs)} QA pairs")
332361
print(f"{'='*70}")
333-
print(pd.DataFrame(all_rows).to_string(index=False))
334362

335-
finally:
363+
retriever = Mem0Retriever(config={
364+
"llm": {"provider": "openai", "config": {"model": "gpt-4.1-mini-2025-04-14"}},
365+
"embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}},
366+
})
367+
368+
user_id = f"locomo_{ci}_{uuid.uuid4().hex[:6]}"
369+
n_memories = ingest_conversation(conv_data, retriever, user_id)
370+
336371
try:
337-
retriever.delete_all_memories(user_id=user_id)
338-
print(f"\nCleaned up memories for {user_id}")
339-
except Exception as e:
340-
print(f"\nCleanup warning: {e}")
341-
del retriever
342-
import gc; gc.collect()
372+
conv_rows = []
373+
for top_k, repeat_times in top_k_configs:
374+
label = f"top_k={top_k}" + (f"x{repeat_times}" if repeat_times > 1 else "")
375+
print(f"\n## {label}")
376+
results = {}
377+
for use_reorder in [False, True]:
378+
cp_reset() # fresh tree for each mode
379+
stats = run_multi_turn(
380+
retriever, user_id, qa_pairs, model, top_k,
381+
use_reorder=use_reorder, cp_available=cp_available,
382+
repeat_times=repeat_times)
383+
results[stats["label"]] = stats
384+
385+
base_ttft = results["baseline"]["ttft"]
386+
387+
k_label = f"{top_k}x{repeat_times}" if repeat_times > 1 else str(top_k)
388+
for name in ["baseline", "reorder"]:
389+
s = results[name]
390+
delta = (base_ttft - s["ttft"]) / base_ttft * 100 if base_ttft else 0
391+
row = {
392+
"conv": ci,
393+
"k": k_label,
394+
"mode": name,
395+
"ttft": s["ttft"],
396+
"ttft_delta": delta if name != "baseline" else 0,
397+
"prefix": s["prefix"],
398+
"f1": s["f1"],
399+
"judge": s["judge"],
400+
}
401+
conv_rows.append(row)
402+
grand_rows.append(row)
403+
404+
# Per-conversation summary
405+
print(f"\n{'='*70}")
406+
print(f"RESULTS (conv={ci}, memories={n_memories}, turns={min(NUM_TURNS, len(qa_pairs))})")
407+
print(f"{'='*70}")
408+
df = pd.DataFrame(conv_rows)
409+
df_display = df.copy()
410+
df_display["ttft"] = df_display["ttft"].map(lambda x: f"{x:.4f}s")
411+
df_display["ttft_delta"] = df.apply(
412+
lambda r: f"{r['ttft_delta']:+.1f}%" if r["mode"] != "baseline" else "-", axis=1)
413+
df_display["prefix"] = df_display["prefix"].map(lambda x: f"{x:.1%}")
414+
df_display["f1"] = df_display["f1"].map(lambda x: f"{x:.3f}")
415+
df_display["judge"] = df_display["judge"].map(lambda x: f"{x:.3f}")
416+
print(df_display.drop(columns=["conv"]).to_string(index=False))
417+
418+
finally:
419+
try:
420+
retriever.delete_all_memories(user_id=user_id)
421+
print(f"\nCleaned up memories for {user_id}")
422+
except Exception as e:
423+
print(f"\nCleanup warning: {e}")
424+
del retriever
425+
import gc; gc.collect()
426+
427+
# Grand aggregate table across all conversations
428+
if len(conv_indices) > 1:
429+
print(f"\n{'='*70}")
430+
print(f"AGGREGATE RESULTS ({len(conv_indices)} conversations)")
431+
print(f"{'='*70}")
432+
gdf = pd.DataFrame(grand_rows)
433+
agg = gdf.groupby(["k", "mode"]).agg(
434+
ttft=("ttft", "mean"),
435+
prefix=("prefix", "mean"),
436+
f1=("f1", "mean"),
437+
judge=("judge", "mean"),
438+
).reset_index()
439+
# Compute delta from baseline per k
440+
for k_val in agg["k"].unique():
441+
base = agg.loc[(agg["k"] == k_val) & (agg["mode"] == "baseline"), "ttft"].values[0]
442+
agg.loc[agg["k"] == k_val, "ttft_delta"] = agg.loc[agg["k"] == k_val, "ttft"].apply(
443+
lambda x: (base - x) / base * 100 if base else 0)
444+
agg_display = agg.copy()
445+
agg_display["ttft"] = agg_display["ttft"].map(lambda x: f"{x:.4f}s")
446+
agg_display["ttft_delta"] = agg.apply(
447+
lambda r: f"{r['ttft_delta']:+.1f}%" if r["mode"] != "baseline" else "-", axis=1)
448+
agg_display["prefix"] = agg_display["prefix"].map(lambda x: f"{x:.1%}")
449+
agg_display["f1"] = agg_display["f1"].map(lambda x: f"{x:.3f}")
450+
agg_display["judge"] = agg_display["judge"].map(lambda x: f"{x:.3f}")
451+
print(agg_display[["k", "mode", "ttft", "ttft_delta", "prefix", "f1", "judge"]].to_string(index=False))

0 commit comments

Comments
 (0)