# Remote location of the LoCoMo-10 benchmark dataset and its on-disk cache path.
LOCOMO_URL = "https://raw.githubusercontent.com/snap-research/locomo/main/data/locomo10.json"
LOCOMO_CACHE = Path(__file__).resolve().parent.parent / "tests" / ".locomo_cache" / "locomo10.json"

# Which conversation(s) to benchmark: an integer index as a string, or "all".
CONV_INDEX = os.environ.get("LOCOMO_CONV_INDEX", "all")
# Maximum number of QA pairs taken from each conversation.
MAX_QA = int(os.environ.get("LOCOMO_MAX_QA", "150"))
# Maximum tokens the model may generate per answer.
MAX_GEN = int(os.environ.get("LOCOMO_MAX_TOKENS", "32"))
# Number of multi-turn rounds per benchmark run.
NUM_TURNS = int(os.environ.get("LOCOMO_NUM_TURNS", "150"))
# Comma-separated top-k settings; "KxR" means retrieve K docs and repeat them
# R times to build a long context. Entries are stripped before parsing, so the
# stray trailing space in the default is harmless — TODO confirm it's intended.
TOP_K_LIST = os.environ.get("LOCOMO_TOP_K_LIST", "20,50,5x10 ")
2121
2222
2323async def _stream_ttft (prompt , model , max_tokens = 512 , request_id = None ):
@@ -56,10 +56,38 @@ def run_ttft(prompt, model, max_tokens=512, request_id=None):
5656 return asyncio .run (_stream_ttft (prompt , model , max_tokens , request_id ))
5757
5858
def build_prompt(question, context_str, importance_ranking=None):
    """Assemble the QA prompt from memories, an optional document-importance
    ranking, and the question.

    The ranking section is inserted only when *importance_ranking* is truthy,
    so an empty string behaves the same as None.
    """
    pieces = [
        f"Memories:\n{context_str}\n",
        "Based on the memories above, concisely answer the following "
        "question in as few words as possible.\n",
    ]
    if importance_ranking:
        pieces.append(
            "Please read the documents in the following importance ranking:\n"
            f"{importance_ranking}\n"
            "Prioritize information from higher-ranked documents.\n"
        )
    pieces.append(f"Question: {question}\nAnswer:")
    return "".join(pieces)
69+
70+
def build_importance_ranking(original_ids, reordered_ids):
    """Map original retrieval order to positions in the reordered doc list.

    With repeated docs the same doc_id appears multiple times, so only the
    *first* occurrence of each unique doc in the original order is kept, and
    it is mapped to that doc's first position in the reordered list.

    Returns a string like "[Doc_2] > [Doc_3] > [Doc_1]".
    """
    # First occurrence of each doc in the reordered list -> its 1-based [Doc_N] slot.
    first_slot = {}
    for rank, doc in enumerate(reordered_ids, start=1):
        first_slot.setdefault(doc, rank)
    # dict.fromkeys deduplicates while preserving first-seen order.
    ordered_unique = dict.fromkeys(original_ids)
    return " > ".join(
        f"[Doc_{first_slot[doc]}]" for doc in ordered_unique if doc in first_slot
    )
6391
6492
6593def llm_judge (question , prediction , ground_truth ):
@@ -120,18 +148,20 @@ def strip_thinking(text):
120148
def build_context_str(doc_ids, corpus_map):
    """Render retrieved docs as a "\\n\\n"-joined context string, each prefixed
    with a 1-based "[Doc_N]" tag so prompts can reference docs by position.

    Doc text is looked up in *corpus_map* by str(doc_id), preferring the
    "text" field, then "content", then a "[doc <id>]" placeholder.
    """
    def _render(slot, doc_id):
        # Per-doc lookup with graceful fallback for missing entries/fields.
        entry = corpus_map.get(str(doc_id), {})
        body = entry.get("text", entry.get("content", f"[doc {doc_id}]"))
        return f"[Doc_{slot}] {body}"

    return "\n\n".join(_render(i + 1, d) for i, d in enumerate(doc_ids))
128156
129157
130158def run_multi_turn (retriever , user_id , qa_pairs , model , top_k ,
131- use_reorder = False , cp_available = False ):
159+ use_reorder = False , cp_available = False , repeat_times = 1 ):
132160 """Run multi-turn benchmark: baseline vs reorder."""
133161 label = "reorder" if use_reorder else "baseline"
134- print (f"\n --- { label } ({ NUM_TURNS } turns, k={ top_k } ) ---" )
162+ actual_k = top_k * repeat_times if repeat_times > 1 else top_k
163+ suffix = f" (k={ top_k } x{ repeat_times } ={ actual_k } docs)" if repeat_times > 1 else f" (k={ top_k } )"
164+ print (f"\n --- { label } ({ NUM_TURNS } turns,{ suffix } ) ---" )
135165
136166 ttfts , prefix_matches , f1s , judges = [], [], [], []
137167
@@ -146,6 +176,11 @@ def run_multi_turn(retriever, user_id, qa_pairs, model, top_k,
146176 cmap = retriever .get_corpus_map ()
147177 doc_ids = s [0 ]["top_k_doc_id" ]
148178
179+ # Repeat docs to create long context if requested
180+ if repeat_times > 1 :
181+ doc_ids = doc_ids * repeat_times
182+
183+ original_ids = list (doc_ids ) # preserve original retrieval order
149184 reordered_ids = doc_ids
150185 req_id = None
151186 server_prefix_len , server_has_prefix , server_node_id = 0 , False , - 1
@@ -179,10 +214,19 @@ def run_multi_turn(retriever, user_id, qa_pairs, model, top_k,
179214 # Build context string directly from corpus map
180215 context_str = build_context_str (reordered_ids , cmap )
181216
217+ # Build importance ranking — always include so prompt length is equal
218+ # between baseline and reorder (fair TTFT comparison).
219+ # Baseline: natural order [Doc_1] > [Doc_2] > ...
220+ # Reorder: original retrieval order mapped to reordered positions
221+ if use_reorder and reordered_ids != original_ids :
222+ importance_ranking = build_importance_ranking (original_ids , reordered_ids )
223+ else :
224+ importance_ranking = " > " .join (f"[Doc_{ i + 1 } ]" for i in range (len (reordered_ids )))
225+
182226 # Build prompt and measure TTFT
183- prompt = build_prompt (qa ["question" ], context_str )
227+ prompt = build_prompt (qa ["question" ], context_str , importance_ranking )
184228 out = run_ttft (prompt , model , MAX_GEN , request_id = req_id )
185- gt = str (qa [ "answer" ] )
229+ gt = str (qa . get ( "answer" , qa . get ( "answers" , qa . get ( "gold_answer" , "" ))) )
186230
187231 if idx > 0 :
188232 ttfts .append (out ["ttft" ])
@@ -219,6 +263,7 @@ def run_multi_turn(retriever, user_id, qa_pairs, model, top_k,
219263 "prefix" : avg (prefix_matches ),
220264 "f1" : avg (f1s ),
221265 "judge" : avg (judges ),
266+ "repeat" : repeat_times ,
222267 }
223268 print (f" [{ label } ] TTFT={ stats ['ttft' ]:.4f} s Prefix={ stats ['prefix' ]:.1%} "
224269 f" F1={ stats ['f1' ]:.3f} Judge={ stats ['judge' ]:.3f} " )
@@ -283,60 +328,124 @@ def ingest_conversation(conv_data, retriever, user_id):
283328 run_ttft ("Hello, world." , model , max_tokens = 4 )
284329 print ("Warmup done.\n " )
285330
286- retriever = Mem0Retriever (config = {
287- "llm" : {"provider" : "openai" , "config" : {"model" : "gpt-4.1-mini-2025-04-14" }},
288- "embedder" : {"provider" : "openai" , "config" : {"model" : "text-embedding-3-small" }},
289- })
290-
291- conv_data = all_convs [CONV_INDEX ]
292- qa_pairs = conv_data ["qa" ][:MAX_QA ]
293- conv = conv_data ["conversation" ]
294- print (f"\n { '=' * 70 } " )
295- print (f"CONV { CONV_INDEX } : { conv ['speaker_a' ]} & { conv ['speaker_b' ]} , { len (qa_pairs )} QA pairs" )
296- print (f"{ '=' * 70 } " )
297-
298- user_id = f"locomo_{ CONV_INDEX } _{ uuid .uuid4 ().hex [:6 ]} "
299- n_memories = ingest_conversation (conv_data , retriever , user_id )
300- top_k_values = [int (k ) for k in TOP_K_LIST .split ("," )]
331+ # Parse TOP_K_LIST: supports "20", "50", or "5x10" (k=5, repeat 10 times)
332+ top_k_configs = []
333+ for entry in TOP_K_LIST .split ("," ):
334+ entry = entry .strip ()
335+ if "x" in entry :
336+ k_str , r_str = entry .split ("x" , 1 )
337+ top_k_configs .append ((int (k_str ), int (r_str )))
338+ else :
339+ top_k_configs .append ((int (entry ), 1 ))
340+
341+ # Determine which conversations to run
342+ if CONV_INDEX == "all" :
343+ conv_indices = list (range (len (all_convs )))
344+ else :
345+ conv_indices = [int (CONV_INDEX )]
346+
347+ grand_rows = [] # aggregate across all conversations
348+
349+ for ci in conv_indices :
350+ # Flush SGLang's radix cache between conversations to avoid pressure buildup
351+ try :
352+ requests .post (f"{ INFERENCE_URL } /flush_cache" , timeout = 5 )
353+ except Exception :
354+ pass
301355
302- try :
303- all_rows = []
304- for top_k in top_k_values :
305- print (f"\n ## top_k={ top_k } " )
306- results = {}
307- for use_reorder in [True , False ]:
308- cp_reset () # fresh tree for each mode
309- stats = run_multi_turn (
310- retriever , user_id , qa_pairs , model , top_k ,
311- use_reorder = use_reorder , cp_available = cp_available )
312- results [stats ["label" ]] = stats
313-
314- base_ttft = results ["baseline" ]["ttft" ]
315-
316- for name in ["baseline" , "reorder" ]:
317- s = results [name ]
318- delta = (base_ttft - s ["ttft" ]) / base_ttft * 100 if base_ttft else 0
319- all_rows .append ({
320- "k" : top_k ,
321- "mode" : name ,
322- "ttft" : f"{ s ['ttft' ]:.4f} s" ,
323- "ttft_delta" : f"{ delta :+.1f} %" if name != "baseline" else "-" ,
324- "prefix" : f"{ s ['prefix' ]:.1%} " ,
325- "f1" : f"{ s ['f1' ]:.3f} " ,
326- "judge" : f"{ s ['judge' ]:.3f} " ,
327- })
328-
329- # Summary table
356+ conv_data = all_convs [ci ]
357+ qa_pairs = conv_data ["qa" ][:MAX_QA ]
358+ conv = conv_data ["conversation" ]
330359 print (f"\n { '=' * 70 } " )
331- print (f"RESULTS (conv= { CONV_INDEX } , memories= { n_memories } , turns= { min ( NUM_TURNS , len (qa_pairs )) } ) " )
360+ print (f"CONV { ci } : { conv [ 'speaker_a' ] } & { conv [ 'speaker_b' ] } , { len (qa_pairs )} QA pairs " )
332361 print (f"{ '=' * 70 } " )
333- print (pd .DataFrame (all_rows ).to_string (index = False ))
334362
335- finally :
363+ retriever = Mem0Retriever (config = {
364+ "llm" : {"provider" : "openai" , "config" : {"model" : "gpt-4.1-mini-2025-04-14" }},
365+ "embedder" : {"provider" : "openai" , "config" : {"model" : "text-embedding-3-small" }},
366+ })
367+
368+ user_id = f"locomo_{ ci } _{ uuid .uuid4 ().hex [:6 ]} "
369+ n_memories = ingest_conversation (conv_data , retriever , user_id )
370+
336371 try :
337- retriever .delete_all_memories (user_id = user_id )
338- print (f"\n Cleaned up memories for { user_id } " )
339- except Exception as e :
340- print (f"\n Cleanup warning: { e } " )
341- del retriever
342- import gc ; gc .collect ()
372+ conv_rows = []
373+ for top_k , repeat_times in top_k_configs :
374+ label = f"top_k={ top_k } " + (f"x{ repeat_times } " if repeat_times > 1 else "" )
375+ print (f"\n ## { label } " )
376+ results = {}
377+ for use_reorder in [False , True ]:
378+ cp_reset () # fresh tree for each mode
379+ stats = run_multi_turn (
380+ retriever , user_id , qa_pairs , model , top_k ,
381+ use_reorder = use_reorder , cp_available = cp_available ,
382+ repeat_times = repeat_times )
383+ results [stats ["label" ]] = stats
384+
385+ base_ttft = results ["baseline" ]["ttft" ]
386+
387+ k_label = f"{ top_k } x{ repeat_times } " if repeat_times > 1 else str (top_k )
388+ for name in ["baseline" , "reorder" ]:
389+ s = results [name ]
390+ delta = (base_ttft - s ["ttft" ]) / base_ttft * 100 if base_ttft else 0
391+ row = {
392+ "conv" : ci ,
393+ "k" : k_label ,
394+ "mode" : name ,
395+ "ttft" : s ["ttft" ],
396+ "ttft_delta" : delta if name != "baseline" else 0 ,
397+ "prefix" : s ["prefix" ],
398+ "f1" : s ["f1" ],
399+ "judge" : s ["judge" ],
400+ }
401+ conv_rows .append (row )
402+ grand_rows .append (row )
403+
404+ # Per-conversation summary
405+ print (f"\n { '=' * 70 } " )
406+ print (f"RESULTS (conv={ ci } , memories={ n_memories } , turns={ min (NUM_TURNS , len (qa_pairs ))} )" )
407+ print (f"{ '=' * 70 } " )
408+ df = pd .DataFrame (conv_rows )
409+ df_display = df .copy ()
410+ df_display ["ttft" ] = df_display ["ttft" ].map (lambda x : f"{ x :.4f} s" )
411+ df_display ["ttft_delta" ] = df .apply (
412+ lambda r : f"{ r ['ttft_delta' ]:+.1f} %" if r ["mode" ] != "baseline" else "-" , axis = 1 )
413+ df_display ["prefix" ] = df_display ["prefix" ].map (lambda x : f"{ x :.1%} " )
414+ df_display ["f1" ] = df_display ["f1" ].map (lambda x : f"{ x :.3f} " )
415+ df_display ["judge" ] = df_display ["judge" ].map (lambda x : f"{ x :.3f} " )
416+ print (df_display .drop (columns = ["conv" ]).to_string (index = False ))
417+
418+ finally :
419+ try :
420+ retriever .delete_all_memories (user_id = user_id )
421+ print (f"\n Cleaned up memories for { user_id } " )
422+ except Exception as e :
423+ print (f"\n Cleanup warning: { e } " )
424+ del retriever
425+ import gc ; gc .collect ()
426+
427+ # Grand aggregate table across all conversations
428+ if len (conv_indices ) > 1 :
429+ print (f"\n { '=' * 70 } " )
430+ print (f"AGGREGATE RESULTS ({ len (conv_indices )} conversations)" )
431+ print (f"{ '=' * 70 } " )
432+ gdf = pd .DataFrame (grand_rows )
433+ agg = gdf .groupby (["k" , "mode" ]).agg (
434+ ttft = ("ttft" , "mean" ),
435+ prefix = ("prefix" , "mean" ),
436+ f1 = ("f1" , "mean" ),
437+ judge = ("judge" , "mean" ),
438+ ).reset_index ()
439+ # Compute delta from baseline per k
440+ for k_val in agg ["k" ].unique ():
441+ base = agg .loc [(agg ["k" ] == k_val ) & (agg ["mode" ] == "baseline" ), "ttft" ].values [0 ]
442+ agg .loc [agg ["k" ] == k_val , "ttft_delta" ] = agg .loc [agg ["k" ] == k_val , "ttft" ].apply (
443+ lambda x : (base - x ) / base * 100 if base else 0 )
444+ agg_display = agg .copy ()
445+ agg_display ["ttft" ] = agg_display ["ttft" ].map (lambda x : f"{ x :.4f} s" )
446+ agg_display ["ttft_delta" ] = agg .apply (
447+ lambda r : f"{ r ['ttft_delta' ]:+.1f} %" if r ["mode" ] != "baseline" else "-" , axis = 1 )
448+ agg_display ["prefix" ] = agg_display ["prefix" ].map (lambda x : f"{ x :.1%} " )
449+ agg_display ["f1" ] = agg_display ["f1" ].map (lambda x : f"{ x :.3f} " )
450+ agg_display ["judge" ] = agg_display ["judge" ].map (lambda x : f"{ x :.3f} " )
451+ print (agg_display [["k" , "mode" , "ttft" , "ttft_delta" , "prefix" , "f1" , "judge" ]].to_string (index = False ))
0 commit comments