Skip to content

Commit 124eb3f

Browse files
feat(parser): improve status updates in layered language derivation (i/total)
1 parent b7afd87 commit 124eb3f

1 file changed

Lines changed: 86 additions & 28 deletions

File tree

src/dylan/parser/language_derivation.py

Lines changed: 86 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,11 @@ def emit_completion_success_only(layer_idx: int, record: LanguageDerivationRecor
662662
for i in range(len(vocab_groups))
663663
]
664664
si = 0
665-
with _layered_console.status("[bold green]Layer 1 (seed) running...[/bold green]"):
665+
n_seed_tasks = len(seed_tasks)
666+
with _layered_console.status(
667+
f"[bold green]Layer 1 (seed)...[/bold green] "
668+
f"[cyan]0/{n_seed_tasks}[/cyan] groups"
669+
) as seed_status:
666670
while si < len(seed_tasks) and not stop:
667671
chunk = seed_tasks[si : si + batch_size]
668672
for row in executor.map(_layered_bfs_seed_group_worker, chunk):
@@ -684,6 +688,10 @@ def emit_completion_success_only(layer_idx: int, record: LanguageDerivationRecor
684688
prefix_words=(),
685689
)
686690
si += len(chunk)
691+
seed_status.update(
692+
f"[bold green]Layer 1 (seed)...[/bold green] "
693+
f"[cyan]{min(si, n_seed_tasks)}/{n_seed_tasks}[/cyan] groups"
694+
)
687695

688696
LanguageDerivation._write_layered_fringe(handles[1][2], frontier)
689697

@@ -726,9 +734,13 @@ def _finalize_prefix_completion(p_end: tuple[str, ...], had_f: bool) -> None:
726734
if max_successful is not None and success_total >= max_successful:
727735
stop = True
728736

737+
n_ext_tasks = len(ext_tasks)
738+
n_frontier_here = len(set(frontier))
729739
with _layered_console.status(
730-
f"[bold green]Layer {depth + 1} (extending {len(frontier)} prefixes)...[/bold green]"
731-
):
740+
f"[bold green]Layer {depth + 1}[/bold green] "
741+
f"(extending {n_frontier_here} prefixes)... "
742+
f"[cyan]0/{n_ext_tasks}[/cyan] tasks"
743+
) as ext_status:
732744
ei = 0
733745
while ei < len(ext_tasks) and not stop:
734746
chunk = ext_tasks[ei : ei + batch_size]
@@ -759,6 +771,11 @@ def _finalize_prefix_completion(p_end: tuple[str, ...], had_f: bool) -> None:
759771
prefix_words=P,
760772
)
761773
ei += len(chunk)
774+
ext_status.update(
775+
f"[bold green]Layer {depth + 1}[/bold green] "
776+
f"(extending {n_frontier_here} prefixes)... "
777+
f"[cyan]{min(ei, n_ext_tasks)}/{n_ext_tasks}[/cyan] tasks"
778+
)
762779
if prev_p is not None:
763780
_finalize_prefix_completion(prev_p, had_failure_for_p)
764781

@@ -778,20 +795,33 @@ def _finalize_prefix_completion(p_end: tuple[str, ...], had_f: bool) -> None:
778795
for P in sorted(set(frontier))
779796
if len(P) == max_len and len(P) >= min_len
780797
]
798+
n_maxlen_tasks = len(maxlen_tasks)
781799
mi = 0
782-
while mi < len(maxlen_tasks) and not stop:
783-
chunk = maxlen_tasks[mi : mi + batch_size]
784-
for rec in executor.map(_layered_bfs_maxlen_completion_worker, chunk):
785-
if stop:
786-
break
787-
if rec is not None:
788-
emit_completion_success_only(max_len, rec)
789-
if max_successful is not None and success_total >= max_successful:
790-
stop = True
791-
mi += len(chunk)
800+
if n_maxlen_tasks > 0:
801+
with _layered_console.status(
802+
f"[bold green]Max-length completions[/bold green] "
803+
f"[cyan]0/{n_maxlen_tasks}[/cyan] prefixes"
804+
) as maxlen_status:
805+
while mi < len(maxlen_tasks) and not stop:
806+
chunk = maxlen_tasks[mi : mi + batch_size]
807+
for rec in executor.map(_layered_bfs_maxlen_completion_worker, chunk):
808+
if stop:
809+
break
810+
if rec is not None:
811+
emit_completion_success_only(max_len, rec)
812+
if max_successful is not None and success_total >= max_successful:
813+
stop = True
814+
mi += len(chunk)
815+
maxlen_status.update(
816+
f"[bold green]Max-length completions[/bold green] "
817+
f"[cyan]{min(mi, n_maxlen_tasks)}/{n_maxlen_tasks}[/cyan] prefixes"
818+
)
792819
else:
793-
with _layered_console.status("[bold green]Layer 1 (seed) running...[/bold green]"):
794-
for group in vocab_groups:
820+
with _layered_console.status(
821+
f"[bold green]Layer 1 (seed)...[/bold green] "
822+
f"[cyan]0/{n_groups}[/cyan] groups"
823+
) as seed_status:
824+
for gi, group in enumerate(vocab_groups, start=1):
795825
if stop:
796826
break
797827
for w in group:
@@ -821,6 +851,10 @@ def _finalize_prefix_completion(p_end: tuple[str, ...], had_f: bool) -> None:
821851
word=group[0],
822852
prefix_words=(),
823853
)
854+
seed_status.update(
855+
f"[bold green]Layer 1 (seed)...[/bold green] "
856+
f"[cyan]{gi}/{n_groups}[/cyan] groups"
857+
)
824858

825859
LanguageDerivation._write_layered_fringe(handles[1][2], frontier)
826860

@@ -838,10 +872,14 @@ def _finalize_prefix_completion(p_end: tuple[str, ...], had_f: bool) -> None:
838872
layer_ext_fail_groups = 0
839873
success_before_depth = success_total
840874

875+
frontier_sorted = sorted(set(frontier))
876+
n_prefixes_here = len(frontier_sorted)
841877
with _layered_console.status(
842-
f"[bold green]Layer {depth + 1} (extending {len(frontier)} prefixes)...[/bold green]"
843-
):
844-
for P in sorted(set(frontier)):
878+
f"[bold green]Layer {depth + 1}[/bold green] "
879+
f"(extending {n_prefixes_here} prefixes)... "
880+
f"[cyan]0/{n_prefixes_here}[/cyan] prefixes"
881+
) as ext_status:
882+
for pi, P in enumerate(frontier_sorted, start=1):
845883
if stop:
846884
break
847885
had_failure = False
@@ -900,6 +938,12 @@ def _finalize_prefix_completion(p_end: tuple[str, ...], had_f: bool) -> None:
900938
if max_successful is not None and success_total >= max_successful:
901939
stop = True
902940

941+
ext_status.update(
942+
f"[bold green]Layer {depth + 1}[/bold green] "
943+
f"(extending {n_prefixes_here} prefixes)... "
944+
f"[cyan]{pi}/{n_prefixes_here}[/cyan] prefixes"
945+
)
946+
903947
frontier = sorted(set(next_frontier))
904948
LanguageDerivation._write_layered_fringe(handles[depth + 1][2], frontier)
905949
new_successes = success_total - success_before_depth
@@ -911,17 +955,31 @@ def _finalize_prefix_completion(p_end: tuple[str, ...], had_f: bool) -> None:
911955
)
912956

913957
if not stop:
914-
for P in sorted(set(frontier)):
915-
if stop:
916-
break
917-
if len(P) == max_len and len(P) >= min_len:
918-
if _replay_prefix(parser, P, speaker=speaker, addressee=addressee):
919-
emit_completion_success_only(
920-
max_len,
921-
_record_completion_from_current_state(parser, sentence=" ".join(P)),
958+
maxlen_prefixes = [
959+
P
960+
for P in sorted(set(frontier))
961+
if len(P) == max_len and len(P) >= min_len
962+
]
963+
n_maxlen_px = len(maxlen_prefixes)
964+
if n_maxlen_px > 0:
965+
with _layered_console.status(
966+
f"[bold green]Max-length completions[/bold green] "
967+
f"[cyan]0/{n_maxlen_px}[/cyan] prefixes"
968+
) as maxlen_status:
969+
for mi, P in enumerate(maxlen_prefixes, start=1):
970+
if stop:
971+
break
972+
if _replay_prefix(parser, P, speaker=speaker, addressee=addressee):
973+
emit_completion_success_only(
974+
max_len,
975+
_record_completion_from_current_state(parser, sentence=" ".join(P)),
976+
)
977+
if max_successful is not None and success_total >= max_successful:
978+
stop = True
979+
maxlen_status.update(
980+
f"[bold green]Max-length completions[/bold green] "
981+
f"[cyan]{mi}/{n_maxlen_px}[/cyan] prefixes"
922982
)
923-
if max_successful is not None and success_total >= max_successful:
924-
stop = True
925983

926984
total_successes = sum(len(ss) for ss in seen_success_by_layer.values())
927985
_layered_console.print(

0 commit comments

Comments
 (0)